forked from M-Labs/nac3
Compare commits
153 Commits
ndstrides-
...
master
Author | SHA1 | Date | |
---|---|---|---|
d394b24304 | |||
68da9b0ecf | |||
eec62c3bbb | |||
b521bc0c82 | |||
96e98947cc | |||
87a637b448 | |||
bdeeced122 | |||
37df08b803 | |||
05fd1a5199 | |||
f817d3347b | |||
2d275949b8 | |||
2783834cb1 | |||
879b063968 | |||
14e80dfab7 | |||
5fdbc34b43 | |||
32f24261f2 | |||
ce40a46f8a | |||
f15a64cc1b | |||
7fac801936 | |||
febfd1241d | |||
4bd5349381 | |||
c15062ab4c | |||
933804e270 | |||
1cfaa1a779 | |||
18e8e5269f | |||
357970a793 | |||
762a2447c3 | |||
8e614d83de | |||
bd66fe48d8 | |||
c59fd286ff | |||
f8530e0ef6 | |||
3ebd4ba5d1 | |||
d1dcfa19ff | |||
8baf111734 | |||
eaaa194a87 | |||
352c7c880b | |||
3c5e247195 | |||
4e21def1a0 | |||
2271b46b96 | |||
d9c180ed13 | |||
8322d457c6 | |||
e480081e4b | |||
12fddc3533 | |||
3ac1083734 | |||
66b8a5e01d | |||
ebbadc2d74 | |||
a2f1b25fd8 | |||
59f19e29df | |||
6cbba8fdde | |||
e6dab25a57 | |||
2dc5e79a23 | |||
dcde1d9c87 | |||
7375983e0c | |||
43e440d2fd | |||
8d975b5ff3 | |||
aae41eef6a | |||
132ba1942f | |||
12358c57b1 | |||
9ffa2d6552 | |||
acb437919d | |||
fadadd7505 | |||
26f1428739 | |||
5880f964bb | |||
7d02f5833d | |||
822f9d33f8 | |||
805a9d23b3 | |||
1ffe2fcc7f | |||
2f0847d77b | |||
dc9efa9e8c | |||
3c0ce3031f | |||
d5e8df070a | |||
dc413dfa43 | |||
19122e2905 | |||
318371a509 | |||
35e3042435 | |||
0e5940c49d | |||
fbf0053c24 | |||
456aefa6ee | |||
49a7469b4a | |||
1531b6cc98 | |||
9bbc40bbfa | |||
790e56d106 | |||
a00eb7969e | |||
27a6f47330 | |||
061747c67b | |||
dc91d9e35a | |||
438943ac6f | |||
678e56c95d | |||
fdfc80ca5f | |||
8b3429d62a | |||
f4c5038b95 | |||
ddd16738a6 | |||
44c49dc102 | |||
e4bd376587 | |||
44498f22f6 | |||
110416d07a | |||
08a7d01a13 | |||
3cd36fddc3 | |||
56a7a9e03d | |||
574ae40f97 | |||
aa293b6bea | |||
eb4b881690 | |||
3d0a1d281c | |||
ad67a99c8f | |||
8e2b50df21 | |||
06092ad29b | |||
d62c6b95fd | |||
95e29d9997 | |||
536ed2146c | |||
d484d44d95 | |||
ac978864f2 | |||
95254f8464 | |||
964945d244 | |||
ae09a0d444 | |||
01edd5af67 | |||
015714eee1 | |||
71dec251e3 | |||
fce61f7b8c | |||
babc081dbd | |||
5337dbe23b | |||
f862c01412 | |||
0c9705f5f1 | |||
5f940f86d9 | |||
5651e00688 | |||
f6745b987f | |||
e0dedc6580 | |||
28f574282c | |||
144f0922db | |||
c58ce9c3a9 | |||
f7e296da53 | |||
b58c99369e | |||
1a535db558 | |||
1ba2e287a6 | |||
f95f979ad3 | |||
48e2148c0f | |||
88e57f7120 | |||
d7633c42bc | |||
a4f53b6e6b | |||
9d9ead211e | |||
26a1b85206 | |||
2822074b2d | |||
fe67ed076c | |||
94e2414df0 | |||
2cee760404 | |||
230982dc84 | |||
2bd3f63991 | |||
b53266e9e6 | |||
86eb22bbf3 | |||
beaa38047d | |||
705dc4ff1c | |||
979209a526 | |||
c3927d0ef6 | |||
202a902cd0 |
2
.cargo/config.toml
Normal file
2
.cargo/config.toml
Normal file
@ -0,0 +1,2 @@
|
||||
[target.x86_64-unknown-linux-gnu]
|
||||
rustflags = ["-C", "link-arg=-fuse-ld=lld"]
|
@ -1,7 +1,7 @@
|
||||
# See https://pre-commit.com for more information
|
||||
# See https://pre-commit.com/hooks.html for more hooks
|
||||
|
||||
default_stages: [commit]
|
||||
default_stages: [pre-commit]
|
||||
|
||||
repos:
|
||||
- repo: local
|
||||
|
511
Cargo.lock
generated
511
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@ -4,6 +4,7 @@ members = [
|
||||
"nac3ast",
|
||||
"nac3parser",
|
||||
"nac3core",
|
||||
"nac3core/nac3core_derive",
|
||||
"nac3standalone",
|
||||
"nac3artiq",
|
||||
"runkernel",
|
||||
|
6
flake.lock
generated
6
flake.lock
generated
@ -2,11 +2,11 @@
|
||||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1727348695,
|
||||
"narHash": "sha256-J+PeFKSDV+pHL7ukkfpVzCOO7mBSrrpJ3svwBFABbhI=",
|
||||
"lastModified": 1736798957,
|
||||
"narHash": "sha256-qwpCtZhSsSNQtK4xYGzMiyEDhkNzOCz/Vfu4oL2ETsQ=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "1925c603f17fc89f4c8f6bf6f631a802ad85d784",
|
||||
"rev": "9abb87b552b7f55ac8916b6fc9e5cb486656a2f3",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
19
flake.nix
19
flake.nix
@ -41,7 +41,7 @@
|
||||
lockFile = ./Cargo.lock;
|
||||
};
|
||||
passthru.cargoLock = cargoLock;
|
||||
nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
|
||||
nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out pkgs.llvmPackages_14.bintools llvm-nac3 ];
|
||||
buildInputs = [ pkgs.python3 llvm-nac3 ];
|
||||
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
|
||||
checkPhase =
|
||||
@ -85,7 +85,7 @@
|
||||
name = "nac3artiq-instrumented";
|
||||
src = self;
|
||||
inherit (nac3artiq) cargoLock;
|
||||
nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-instrumented ];
|
||||
nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt pkgs.llvmPackages_14.bintools llvm-nac3-instrumented ];
|
||||
buildInputs = [ pkgs.python3 llvm-nac3-instrumented ];
|
||||
cargoBuildFlags = [ "--package" "nac3artiq" "--features" "init-llvm-profile" ];
|
||||
doCheck = false;
|
||||
@ -107,19 +107,20 @@
|
||||
(pkgs.fetchFromGitHub {
|
||||
owner = "m-labs";
|
||||
repo = "sipyco";
|
||||
rev = "939f84f9b5eef7efbf7423c735d1834783b6140e";
|
||||
sha256 = "sha256-15Nun4EY35j+6SPZkjzZtyH/ncxLS60KuGJjFh5kSTc=";
|
||||
rev = "094a6cd63ffa980ef63698920170e50dc9ba77fd";
|
||||
sha256 = "sha256-PPnAyDedUQ7Og/Cby9x5OT9wMkNGTP8GS53V6N/dk4w=";
|
||||
})
|
||||
(pkgs.fetchFromGitHub {
|
||||
owner = "m-labs";
|
||||
repo = "artiq";
|
||||
rev = "923ca3377d42c815f979983134ec549dc39d3ca0";
|
||||
sha256 = "sha256-oJoEeNEeNFSUyh6jXG8Tzp6qHVikeHS0CzfE+mODPgw=";
|
||||
rev = "28c9de3e251daa89a8c9fd79d5ab64a3ec03bac6";
|
||||
sha256 = "sha256-vAvpbHc5B+1wtG8zqN7j9dQE1ON+i22v+uqA+tw6Gak=";
|
||||
})
|
||||
];
|
||||
buildInputs = [
|
||||
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb nac3artiq-instrumented ]))
|
||||
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb ps.platformdirs nac3artiq-instrumented ]))
|
||||
pkgs.llvmPackages_14.llvm.out
|
||||
pkgs.llvmPackages_14.bintools
|
||||
];
|
||||
phases = [ "buildPhase" "installPhase" ];
|
||||
buildPhase =
|
||||
@ -147,7 +148,7 @@
|
||||
name = "nac3artiq-pgo";
|
||||
src = self;
|
||||
inherit (nac3artiq) cargoLock;
|
||||
nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-pgo ];
|
||||
nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt pkgs.llvmPackages_14.bintools llvm-nac3-pgo ];
|
||||
buildInputs = [ pkgs.python3 llvm-nac3-pgo ];
|
||||
cargoBuildFlags = [ "--package" "nac3artiq" ];
|
||||
cargoTestFlags = [ "--package" "nac3ast" "--package" "nac3parser" "--package" "nac3core" "--package" "nac3artiq" ];
|
||||
@ -168,7 +169,7 @@
|
||||
buildInputs = with pkgs; [
|
||||
# build dependencies
|
||||
packages.x86_64-linux.llvm-nac3
|
||||
(pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos
|
||||
(pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out llvmPackages_14.bintools # for running nac3standalone demos
|
||||
packages.x86_64-linux.llvm-tools-irrt
|
||||
cargo
|
||||
rustc
|
||||
|
@ -9,10 +9,10 @@ name = "nac3artiq"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
itertools = "0.13"
|
||||
itertools = "0.14"
|
||||
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
|
||||
parking_lot = "0.12"
|
||||
tempfile = "3.13"
|
||||
tempfile = "3.16"
|
||||
nac3core = { path = "../nac3core" }
|
||||
nac3ld = { path = "../nac3ld" }
|
||||
|
||||
|
@ -1,66 +0,0 @@
|
||||
class EmbeddingMap:
|
||||
def __init__(self):
|
||||
self.object_inverse_map = {}
|
||||
self.object_map = {}
|
||||
self.string_map = {}
|
||||
self.string_reverse_map = {}
|
||||
self.function_map = {}
|
||||
self.attributes_writeback = []
|
||||
|
||||
# preallocate exception names
|
||||
self.preallocate_runtime_exception_names(["RuntimeError",
|
||||
"RTIOUnderflow",
|
||||
"RTIOOverflow",
|
||||
"RTIODestinationUnreachable",
|
||||
"DMAError",
|
||||
"I2CError",
|
||||
"CacheError",
|
||||
"SPIError",
|
||||
"0:ZeroDivisionError",
|
||||
"0:IndexError",
|
||||
"0:ValueError",
|
||||
"0:RuntimeError",
|
||||
"0:AssertionError",
|
||||
"0:KeyError",
|
||||
"0:NotImplementedError",
|
||||
"0:OverflowError",
|
||||
"0:IOError",
|
||||
"0:UnwrapNoneError"])
|
||||
|
||||
def preallocate_runtime_exception_names(self, names):
|
||||
for i, name in enumerate(names):
|
||||
if ":" not in name:
|
||||
name = "0:artiq.coredevice.exceptions." + name
|
||||
exn_id = self.store_str(name)
|
||||
assert exn_id == i
|
||||
|
||||
def store_function(self, key, fun):
|
||||
self.function_map[key] = fun
|
||||
return key
|
||||
|
||||
def store_object(self, obj):
|
||||
obj_id = id(obj)
|
||||
if obj_id in self.object_inverse_map:
|
||||
return self.object_inverse_map[obj_id]
|
||||
key = len(self.object_map) + 1
|
||||
self.object_map[key] = obj
|
||||
self.object_inverse_map[obj_id] = key
|
||||
return key
|
||||
|
||||
def store_str(self, s):
|
||||
if s in self.string_reverse_map:
|
||||
return self.string_reverse_map[s]
|
||||
key = len(self.string_map)
|
||||
self.string_map[key] = s
|
||||
self.string_reverse_map[s] = key
|
||||
return key
|
||||
|
||||
def retrieve_function(self, key):
|
||||
return self.function_map[key]
|
||||
|
||||
def retrieve_object(self, key):
|
||||
return self.object_map[key]
|
||||
|
||||
def retrieve_str(self, key):
|
||||
return self.string_map[key]
|
||||
|
@ -6,7 +6,6 @@ from typing import Generic, TypeVar
|
||||
from math import floor, ceil
|
||||
|
||||
import nac3artiq
|
||||
from embedding_map import EmbeddingMap
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -193,6 +192,46 @@ def print_int64(x: int64):
|
||||
raise NotImplementedError("syscall not simulated")
|
||||
|
||||
|
||||
class EmbeddingMap:
|
||||
def __init__(self):
|
||||
self.object_inverse_map = {}
|
||||
self.object_map = {}
|
||||
self.string_map = {}
|
||||
self.string_reverse_map = {}
|
||||
self.function_map = {}
|
||||
self.attributes_writeback = []
|
||||
|
||||
def store_function(self, key, fun):
|
||||
self.function_map[key] = fun
|
||||
return key
|
||||
|
||||
def store_object(self, obj):
|
||||
obj_id = id(obj)
|
||||
if obj_id in self.object_inverse_map:
|
||||
return self.object_inverse_map[obj_id]
|
||||
key = len(self.object_map) + 1
|
||||
self.object_map[key] = obj
|
||||
self.object_inverse_map[obj_id] = key
|
||||
return key
|
||||
|
||||
def store_str(self, s):
|
||||
if s in self.string_reverse_map:
|
||||
return self.string_reverse_map[s]
|
||||
key = len(self.string_map)
|
||||
self.string_map[key] = s
|
||||
self.string_reverse_map[s] = key
|
||||
return key
|
||||
|
||||
def retrieve_function(self, key):
|
||||
return self.function_map[key]
|
||||
|
||||
def retrieve_object(self, key):
|
||||
return self.object_map[key]
|
||||
|
||||
def retrieve_str(self, key):
|
||||
return self.string_map[key]
|
||||
|
||||
|
||||
@nac3
|
||||
class Core:
|
||||
ref_period: KernelInvariant[float]
|
||||
@ -206,7 +245,7 @@ class Core:
|
||||
embedding = EmbeddingMap()
|
||||
|
||||
if allow_registration:
|
||||
compiler.analyze(registered_functions, registered_classes)
|
||||
compiler.analyze(registered_functions, registered_classes, set())
|
||||
allow_registration = False
|
||||
|
||||
if hasattr(method, "__self__"):
|
||||
|
26
nac3artiq/demo/module.py
Normal file
26
nac3artiq/demo/module.py
Normal file
@ -0,0 +1,26 @@
|
||||
from min_artiq import *
|
||||
from numpy import int32
|
||||
|
||||
# Global Variable Definition
|
||||
X: Kernel[int32] = 1
|
||||
|
||||
# TopLevelFunction Defintion
|
||||
@kernel
|
||||
def display_X():
|
||||
print_int32(X)
|
||||
|
||||
# TopLevel Class Definition
|
||||
@nac3
|
||||
class A:
|
||||
@kernel
|
||||
def __init__(self):
|
||||
self.set_x(1)
|
||||
|
||||
@kernel
|
||||
def set_x(self, new_val: int32):
|
||||
global X
|
||||
X = new_val
|
||||
|
||||
@kernel
|
||||
def get_X(self) -> int32:
|
||||
return X
|
26
nac3artiq/demo/module_support.py
Normal file
26
nac3artiq/demo/module_support.py
Normal file
@ -0,0 +1,26 @@
|
||||
from min_artiq import *
|
||||
import module as module_definition
|
||||
|
||||
@nac3
|
||||
class TestModuleSupport:
|
||||
core: KernelInvariant[Core]
|
||||
|
||||
def __init__(self):
|
||||
self.core = Core()
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
# Accessing classes
|
||||
obj = module_definition.A()
|
||||
obj.get_X()
|
||||
obj.set_x(2)
|
||||
|
||||
# Calling functions
|
||||
module_definition.display_X()
|
||||
|
||||
# Updating global variables
|
||||
module_definition.X = 9
|
||||
module_definition.display_X()
|
||||
|
||||
if __name__ == "__main__":
|
||||
TestModuleSupport().run()
|
29
nac3artiq/demo/numpy_primitives_decay.py
Normal file
29
nac3artiq/demo/numpy_primitives_decay.py
Normal file
@ -0,0 +1,29 @@
|
||||
from min_artiq import *
|
||||
import numpy
|
||||
from numpy import int32
|
||||
|
||||
|
||||
@nac3
|
||||
class NumpyBoolDecay:
|
||||
core: KernelInvariant[Core]
|
||||
np_true: KernelInvariant[bool]
|
||||
np_false: KernelInvariant[bool]
|
||||
np_int: KernelInvariant[int32]
|
||||
np_float: KernelInvariant[float]
|
||||
np_str: KernelInvariant[str]
|
||||
|
||||
def __init__(self):
|
||||
self.core = Core()
|
||||
self.np_true = numpy.True_
|
||||
self.np_false = numpy.False_
|
||||
self.np_int = numpy.int32(0)
|
||||
self.np_float = numpy.float64(0.0)
|
||||
self.np_str = numpy.str_("")
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
NumpyBoolDecay().run()
|
@ -1,24 +0,0 @@
|
||||
from min_artiq import *
|
||||
from numpy import int32
|
||||
|
||||
|
||||
@nac3
|
||||
class Demo:
|
||||
core: KernelInvariant[Core]
|
||||
attr1: KernelInvariant[str]
|
||||
attr2: KernelInvariant[int32]
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.core = Core()
|
||||
self.attr2 = 32
|
||||
self.attr1 = "SAMPLE"
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
print_int32(self.attr2)
|
||||
self.attr1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Demo().run()
|
@ -1,40 +0,0 @@
|
||||
from min_artiq import *
|
||||
from numpy import int32
|
||||
|
||||
|
||||
@nac3
|
||||
class Demo:
|
||||
attr1: KernelInvariant[int32] = 2
|
||||
attr2: int32 = 4
|
||||
attr3: Kernel[int32]
|
||||
|
||||
@kernel
|
||||
def __init__(self):
|
||||
self.attr3 = 8
|
||||
|
||||
|
||||
@nac3
|
||||
class NAC3Devices:
|
||||
core: KernelInvariant[Core]
|
||||
attr4: KernelInvariant[int32] = 16
|
||||
|
||||
def __init__(self):
|
||||
self.core = Core()
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
Demo.attr1 # Supported
|
||||
# Demo.attr2 # Field not accessible on Kernel
|
||||
# Demo.attr3 # Only attributes can be accessed in this way
|
||||
# Demo.attr1 = 2 # Attributes are immutable
|
||||
|
||||
self.attr4 # Attributes can be accessed within class
|
||||
|
||||
obj = Demo()
|
||||
obj.attr1 # Attributes can be accessed by class objects
|
||||
|
||||
NAC3Devices.attr4 # Attributes accessible for classes without __init__
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
NAC3Devices().run()
|
@ -12,33 +12,38 @@ use pyo3::{
|
||||
PyObject, PyResult, Python,
|
||||
};
|
||||
|
||||
use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
||||
use nac3core::{
|
||||
codegen::{
|
||||
classes::{
|
||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
|
||||
NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor,
|
||||
},
|
||||
expr::{destructure_range, gen_call},
|
||||
irrt::call_ndarray_calc_size,
|
||||
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave},
|
||||
llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave},
|
||||
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
|
||||
type_aligned_alloca,
|
||||
types::{ndarray::NDArrayType, RangeType},
|
||||
values::{
|
||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue,
|
||||
UntypedArrayLikeAccessor,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
inkwell::{
|
||||
context::Context,
|
||||
module::Linkage,
|
||||
targets::TargetMachine,
|
||||
types::{BasicType, IntType},
|
||||
values::{BasicValueEnum, IntValue, PointerValue, StructValue},
|
||||
AddressSpace, IntPredicate, OptimizationLevel,
|
||||
},
|
||||
nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall},
|
||||
toplevel::{
|
||||
helper::{extract_ndims, PrimDef},
|
||||
numpy::unpack_ndarray_var_tys,
|
||||
DefinitionId, GenCall,
|
||||
},
|
||||
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
|
||||
};
|
||||
|
||||
use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
||||
|
||||
/// The parallelism mode within a block.
|
||||
#[derive(Copy, Clone, Eq, PartialEq)]
|
||||
enum ParallelMode {
|
||||
@ -83,13 +88,13 @@ pub struct ArtiqCodeGenerator<'a> {
|
||||
impl<'a> ArtiqCodeGenerator<'a> {
|
||||
pub fn new(
|
||||
name: String,
|
||||
size_t: u32,
|
||||
size_t: IntType<'_>,
|
||||
timeline: &'a (dyn TimeFns + Sync),
|
||||
) -> ArtiqCodeGenerator<'a> {
|
||||
assert!(size_t == 32 || size_t == 64);
|
||||
assert!(matches!(size_t.get_bit_width(), 32 | 64));
|
||||
ArtiqCodeGenerator {
|
||||
name,
|
||||
size_t,
|
||||
size_t: size_t.get_bit_width(),
|
||||
name_counter: 0,
|
||||
start: None,
|
||||
end: None,
|
||||
@ -98,6 +103,17 @@ impl<'a> ArtiqCodeGenerator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_target_machine(
|
||||
name: String,
|
||||
ctx: &Context,
|
||||
target_machine: &TargetMachine,
|
||||
timeline: &'a (dyn TimeFns + Sync),
|
||||
) -> ArtiqCodeGenerator<'a> {
|
||||
let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None);
|
||||
Self::new(name, llvm_usize, timeline)
|
||||
}
|
||||
|
||||
/// If the generator is currently in a direct-`parallel` block context, emits IR that resets the
|
||||
/// position of the timeline to the initial timeline position before entering the `parallel`
|
||||
/// block.
|
||||
@ -158,7 +174,7 @@ impl<'a> ArtiqCodeGenerator<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
|
||||
impl CodeGenerator for ArtiqCodeGenerator<'_> {
|
||||
fn get_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
@ -455,54 +471,51 @@ fn format_rpc_arg<'ctx>(
|
||||
// libproto_artiq: NDArray = [data[..], dim_sz[..]]
|
||||
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
|
||||
let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
|
||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
let dtype = ctx.get_llvm_type(generator, elem_ty);
|
||||
let ndarray = NDArrayType::new(ctx, dtype, ndims)
|
||||
.map_pointer_value(arg.into_pointer_value(), None);
|
||||
|
||||
let llvm_usize_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "")
|
||||
.unwrap();
|
||||
let llvm_pdata_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(
|
||||
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
|
||||
llvm_usize,
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
let ndims = llvm_usize.const_int(ndims, false);
|
||||
|
||||
let dims_buf_sz =
|
||||
ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
||||
// `ndarray.data` is possibly not contiguous, and we need it to be contiguous for
|
||||
// the reader.
|
||||
// Turning it into a ContiguousNDArray to get a `data` that is contiguous.
|
||||
let carray = ndarray.make_contiguous_ndarray(generator, ctx);
|
||||
|
||||
let buffer_size =
|
||||
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
|
||||
let sizeof_usize = llvm_usize.size_of();
|
||||
let sizeof_usize =
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(sizeof_usize, llvm_usize, "").unwrap();
|
||||
|
||||
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
|
||||
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
|
||||
let sizeof_pdata = dtype.ptr_type(AddressSpace::default()).size_of();
|
||||
let sizeof_pdata =
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(sizeof_pdata, llvm_usize, "").unwrap();
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
buffer.base_ptr(ctx, generator),
|
||||
llvm_arg.ptr_to_data(ctx),
|
||||
llvm_pdata_sizeof,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
let sizeof_buf_shape = ctx.builder.build_int_mul(sizeof_usize, ndims, "").unwrap();
|
||||
let sizeof_buf = ctx.builder.build_int_add(sizeof_buf_shape, sizeof_pdata, "").unwrap();
|
||||
|
||||
let pbuffer_dims_begin =
|
||||
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
pbuffer_dims_begin,
|
||||
llvm_arg.dim_sizes().base_ptr(ctx, generator),
|
||||
dims_buf_sz,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
// buf = { data: void*, shape: [size_t; ndims]; }
|
||||
let buf = ctx.builder.build_array_alloca(llvm_i8, sizeof_buf, "rpc.arg").unwrap();
|
||||
let buf = ArraySliceValue::from_ptr_val(buf, sizeof_buf, Some("rpc.arg"));
|
||||
let buf_data = buf.base_ptr(ctx, generator);
|
||||
let buf_shape =
|
||||
unsafe { buf.ptr_offset_unchecked(ctx, generator, &sizeof_pdata, None) };
|
||||
|
||||
buffer.base_ptr(ctx, generator)
|
||||
// Write to `buf->data`
|
||||
let carray_data = carray.load_data(ctx);
|
||||
let carray_data = ctx.builder.build_pointer_cast(carray_data, llvm_pi8, "").unwrap();
|
||||
call_memcpy(ctx, buf_data, carray_data, sizeof_pdata, llvm_i1.const_zero());
|
||||
|
||||
// Write to `buf->shape`
|
||||
let carray_shape = ndarray.shape().base_ptr(ctx, generator);
|
||||
let carray_shape_i8 =
|
||||
ctx.builder.build_pointer_cast(carray_shape, llvm_pi8, "").unwrap();
|
||||
call_memcpy(ctx, buf_shape, carray_shape_i8, sizeof_buf_shape, llvm_i1.const_zero());
|
||||
|
||||
buf.base_ptr(ctx, generator)
|
||||
}
|
||||
|
||||
_ => {
|
||||
@ -543,6 +556,8 @@ fn format_rpc_ret<'ctx>(
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
|
||||
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
|
||||
@ -563,8 +578,7 @@ fn format_rpc_ret<'ctx>(
|
||||
|
||||
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let num_0 = llvm_usize.const_zero();
|
||||
|
||||
// Round `val` up to its modulo `power_of_two`
|
||||
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -590,79 +604,49 @@ fn format_rpc_ret<'ctx>(
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
// Setup types
|
||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
|
||||
|
||||
// Allocate the resulting ndarray
|
||||
// A condition after format_rpc_ret ensures this will not be popped this off.
|
||||
let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result"));
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
|
||||
let dtype_llvm = ctx.get_llvm_type(generator, dtype);
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
let ndarray = NDArrayType::new(ctx, dtype_llvm, ndims)
|
||||
.construct_uninitialized(generator, ctx, None);
|
||||
|
||||
// Setup ndims
|
||||
let ndims =
|
||||
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) {
|
||||
assert_eq!(values.len(), 1);
|
||||
// NOTE: Current content of `ndarray`:
|
||||
// - * `data` - **NOT YET** allocated.
|
||||
// - * `itemsize` - initialized to be size_of(dtype).
|
||||
// - * `ndims` - initialized.
|
||||
// - * `shape` - allocated; has uninitialized values.
|
||||
// - * `strides` - allocated; has uninitialized values.
|
||||
|
||||
u64::try_from(values[0].clone()).unwrap()
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
// Set `ndarray.ndims`
|
||||
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
|
||||
// Allocate `ndarray.shape` [size_t; ndims]
|
||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx));
|
||||
|
||||
/*
|
||||
ndarray now:
|
||||
- .ndims: initialized
|
||||
- .shape: allocated but uninitialized .shape
|
||||
- .data: uninitialized
|
||||
*/
|
||||
|
||||
let llvm_usize_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
|
||||
.unwrap();
|
||||
let llvm_pdata_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(
|
||||
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
|
||||
llvm_usize,
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
let llvm_elem_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
|
||||
.unwrap();
|
||||
let itemsize = ndarray.load_itemsize(ctx); // Same as doing a `ctx.get_llvm_type` on `dtype` and get its `size_of()`.
|
||||
|
||||
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
|
||||
// (4 + 4 * ndims) bytes with 8-byte alignment
|
||||
let sizeof_dims =
|
||||
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
||||
let sizeof_usize = llvm_usize.size_of();
|
||||
let sizeof_usize =
|
||||
ctx.builder.build_int_truncate_or_bit_cast(sizeof_usize, llvm_usize, "").unwrap();
|
||||
|
||||
let sizeof_ptr = llvm_i8.ptr_type(AddressSpace::default()).size_of();
|
||||
let sizeof_ptr =
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(sizeof_ptr, llvm_usize, "").unwrap();
|
||||
|
||||
let sizeof_shape =
|
||||
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), sizeof_usize, "").unwrap();
|
||||
|
||||
// Size of the buffer for the initial `rpc_recv()`.
|
||||
let unaligned_buffer_size =
|
||||
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
|
||||
let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false));
|
||||
ctx.builder.build_int_add(sizeof_ptr, sizeof_shape, "").unwrap();
|
||||
|
||||
let stackptr = call_stacksave(ctx, None);
|
||||
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment
|
||||
let buffer = ctx
|
||||
.builder
|
||||
.build_array_alloca(
|
||||
llvm_i8_8,
|
||||
ctx.builder
|
||||
.build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "")
|
||||
.unwrap(),
|
||||
"rpc.buffer",
|
||||
)
|
||||
.unwrap();
|
||||
let buffer = ctx
|
||||
.builder
|
||||
.build_bit_cast(buffer, llvm_pi8, "")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap();
|
||||
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
|
||||
let buffer = type_aligned_alloca(
|
||||
generator,
|
||||
ctx,
|
||||
llvm_i8_8,
|
||||
unaligned_buffer_size,
|
||||
Some("rpc.buffer"),
|
||||
);
|
||||
let buffer = ArraySliceValue::from_ptr_val(buffer, unaligned_buffer_size, None);
|
||||
|
||||
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
|
||||
//
|
||||
@ -670,7 +654,7 @@ fn format_rpc_ret<'ctx>(
|
||||
let ndarray_nbytes = ctx
|
||||
.build_call_or_invoke(
|
||||
rpc_recv,
|
||||
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims].
|
||||
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]
|
||||
"rpc.size.next",
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
@ -678,16 +662,14 @@ fn format_rpc_ret<'ctx>(
|
||||
|
||||
// debug_assert(ndarray_nbytes > 0)
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let cmp = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::UGT, ndarray_nbytes, num_0, "")
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::UGT,
|
||||
ndarray_nbytes,
|
||||
ndarray_nbytes.get_type().const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap(),
|
||||
cmp,
|
||||
"0:AssertionError",
|
||||
"Unexpected RPC termination for ndarray - Expected data buffer next",
|
||||
[None, None, None],
|
||||
@ -696,49 +678,50 @@ fn format_rpc_ret<'ctx>(
|
||||
}
|
||||
|
||||
// Copy shape from the buffer to `ndarray.shape`.
|
||||
let pbuffer_dims =
|
||||
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
|
||||
// We need to skip the first `sizeof(uint8_t*)` bytes to skip the `pdata` in `[pdata, shape]`.
|
||||
let pbuffer_shape =
|
||||
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &sizeof_ptr, None) };
|
||||
let pbuffer_shape =
|
||||
ctx.builder.build_pointer_cast(pbuffer_shape, llvm_pusize, "").unwrap();
|
||||
|
||||
// Copy shape from buffer to `ndarray.shape`
|
||||
ndarray.copy_shape_from_array(generator, ctx, pbuffer_shape);
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
ndarray.dim_sizes().base_ptr(ctx, generator),
|
||||
pbuffer_dims,
|
||||
sizeof_dims,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
// Restore stack from before allocation of buffer
|
||||
call_stackrestore(ctx, stackptr);
|
||||
|
||||
// Allocate `ndarray.data`.
|
||||
// `ndarray.shape` must be initialized beforehand in this implementation
|
||||
// (for ndarray.create_data() to know how many elements to allocate)
|
||||
let num_elements =
|
||||
call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None));
|
||||
unsafe { ndarray.create_data(generator, ctx) }; // NOTE: the strides of `ndarray` has also been set to contiguous in `create_data`.
|
||||
|
||||
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let sizeof_data =
|
||||
ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
|
||||
let num_elements = ndarray.size(ctx);
|
||||
|
||||
let expected_ndarray_nbytes =
|
||||
ctx.builder.build_int_mul(num_elements, itemsize, "").unwrap();
|
||||
let cmp = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::UGE,
|
||||
expected_ndarray_nbytes,
|
||||
ndarray_nbytes,
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::UGE,
|
||||
sizeof_data,
|
||||
ndarray_nbytes,
|
||||
"",
|
||||
).unwrap(),
|
||||
cmp,
|
||||
"0:AssertionError",
|
||||
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
|
||||
[Some(sizeof_data), Some(ndarray_nbytes), None],
|
||||
[Some(expected_ndarray_nbytes), Some(ndarray_nbytes), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
ndarray.create_data(ctx, llvm_elem_ty, num_elements);
|
||||
|
||||
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
|
||||
let ndarray_data_i8 =
|
||||
ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
|
||||
|
||||
// NOTE: Currently on `prehead_bb`
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
@ -747,7 +730,7 @@ fn format_rpc_ret<'ctx>(
|
||||
ctx.builder.position_at_end(head_bb);
|
||||
|
||||
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]);
|
||||
phi.add_incoming(&[(&ndarray_data, prehead_bb)]);
|
||||
|
||||
let alloc_size = ctx
|
||||
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
||||
@ -762,12 +745,13 @@ fn format_rpc_ret<'ctx>(
|
||||
|
||||
ctx.builder.position_at_end(alloc_bb);
|
||||
// Align the allocation to sizeof(T)
|
||||
let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof);
|
||||
let alloc_size = round_up(ctx, alloc_size, itemsize);
|
||||
// TODO(Derppening): Candidate for refactor into type_aligned_alloca
|
||||
let alloc_ptr = ctx
|
||||
.builder
|
||||
.build_array_alloca(
|
||||
llvm_elem_ty,
|
||||
ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(),
|
||||
dtype_llvm,
|
||||
ctx.builder.build_int_unsigned_div(alloc_size, itemsize, "").unwrap(),
|
||||
"rpc.alloc",
|
||||
)
|
||||
.unwrap();
|
||||
@ -777,7 +761,7 @@ fn format_rpc_ret<'ctx>(
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(tail_bb);
|
||||
ndarray.as_base_value().into()
|
||||
ndarray.as_abi_value(ctx).into()
|
||||
}
|
||||
|
||||
_ => {
|
||||
@ -825,7 +809,7 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||
let int8 = ctx.ctx.i8_type();
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let size_type = generator.get_size_type(ctx.ctx);
|
||||
let size_type = ctx.get_size_type();
|
||||
let ptr_type = int8.ptr_type(AddressSpace::default());
|
||||
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
||||
|
||||
@ -990,11 +974,12 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn attributes_writeback(
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
pub fn attributes_writeback<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
inner_resolver: &InnerResolver,
|
||||
host_attributes: &PyObject,
|
||||
return_obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
) -> Result<(), String> {
|
||||
Python::with_gil(|py| -> PyResult<Result<(), String>> {
|
||||
let host_attributes: &PyList = host_attributes.downcast(py)?;
|
||||
@ -1004,6 +989,11 @@ pub fn attributes_writeback(
|
||||
let zero = int32.const_zero();
|
||||
let mut values = Vec::new();
|
||||
let mut scratch_buffer = Vec::new();
|
||||
|
||||
if let Some((ty, obj)) = return_obj {
|
||||
values.push((ty, obj.to_basic_value_enum(ctx, generator, ty).unwrap()));
|
||||
}
|
||||
|
||||
for val in (*globals).values() {
|
||||
let val = val.as_ref(py);
|
||||
let ty = inner_resolver.get_obj_type(
|
||||
@ -1062,6 +1052,34 @@ pub fn attributes_writeback(
|
||||
));
|
||||
}
|
||||
}
|
||||
TypeEnum::TModule { attributes, .. } => {
|
||||
let mut fields = Vec::new();
|
||||
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
|
||||
|
||||
for (name, (field_ty, is_method)) in attributes {
|
||||
if *is_method {
|
||||
continue;
|
||||
}
|
||||
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
|
||||
fields.push(name.to_string());
|
||||
let (index, _) = ctx.get_attr_index(ty, *name);
|
||||
values.push((
|
||||
*field_ty,
|
||||
ctx.build_gep_and_load(
|
||||
obj.into_pointer_value(),
|
||||
&[zero, int32.const_int(index as u64, false)],
|
||||
None,
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
if !fields.is_empty() {
|
||||
let pydict = PyDict::new(py);
|
||||
pydict.set_item("obj", val)?;
|
||||
pydict.set_item("fields", fields)?;
|
||||
host_attributes.append(pydict)?;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
@ -1082,7 +1100,7 @@ pub fn attributes_writeback(
|
||||
let args: Vec<_> =
|
||||
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
|
||||
if let Err(e) =
|
||||
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator, false)
|
||||
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator, true)
|
||||
{
|
||||
return Ok(Err(e));
|
||||
}
|
||||
@ -1177,7 +1195,7 @@ fn polymorphic_print<'ctx>(
|
||||
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_i64 = ctx.ctx.i64_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
let suffix = suffix.unwrap_or_default();
|
||||
|
||||
@ -1308,7 +1326,8 @@ fn polymorphic_print<'ctx>(
|
||||
fmt.push('[');
|
||||
flush(ctx, generator, &mut fmt, &mut args);
|
||||
|
||||
let val = ListValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None);
|
||||
let val =
|
||||
ListValue::from_pointer_value(value.into_pointer_value(), llvm_usize, None);
|
||||
let len = val.load_size(ctx, None);
|
||||
let last =
|
||||
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
||||
@ -1359,56 +1378,50 @@ fn polymorphic_print<'ctx>(
|
||||
}
|
||||
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
|
||||
fmt.push_str("array([");
|
||||
flush(ctx, generator, &mut fmt, &mut args);
|
||||
|
||||
let val = NDArrayValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None);
|
||||
let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None));
|
||||
let last =
|
||||
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
|
||||
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
let ndarray = NDArrayType::from_unifier_type(generator, ctx, ty)
|
||||
.map_pointer_value(value.into_pointer_value(), None);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(len, false),
|
||||
|generator, ctx, _, i| {
|
||||
let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) };
|
||||
let num_0 = llvm_usize.const_zero();
|
||||
|
||||
polymorphic_print(
|
||||
ctx,
|
||||
generator,
|
||||
&[(elem_ty, elem.into())],
|
||||
"",
|
||||
None,
|
||||
true,
|
||||
as_rtio,
|
||||
)?;
|
||||
// Print `ndarray` as a flat list delimited by interspersed with ", \0"
|
||||
ndarray.foreach(generator, ctx, |generator, ctx, _, hdl| {
|
||||
let i = hdl.get_index(ctx);
|
||||
let scalar = hdl.get_scalar(ctx);
|
||||
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::ULT, i, last, "")
|
||||
.unwrap())
|
||||
},
|
||||
|generator, ctx| {
|
||||
printf(ctx, generator, ", \0".into(), Vec::default());
|
||||
// if (i != 0) puts(", ");
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
let not_first = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::NE, i, num_0, "")
|
||||
.unwrap();
|
||||
Ok(not_first)
|
||||
},
|
||||
|generator, ctx| {
|
||||
printf(ctx, generator, ", \0".into(), Vec::default());
|
||||
Ok(())
|
||||
},
|
||||
|_, _| Ok(()),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, _| Ok(()),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
// Print element
|
||||
polymorphic_print(
|
||||
ctx,
|
||||
generator,
|
||||
&[(dtype, scalar.into())],
|
||||
"",
|
||||
None,
|
||||
true,
|
||||
as_rtio,
|
||||
)?;
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
fmt.push_str(")]");
|
||||
flush(ctx, generator, &mut fmt, &mut args);
|
||||
@ -1418,7 +1431,7 @@ fn polymorphic_print<'ctx>(
|
||||
fmt.push_str("range(");
|
||||
flush(ctx, generator, &mut fmt, &mut args);
|
||||
|
||||
let val = RangeValue::from_ptr_val(value.into_pointer_value(), None);
|
||||
let val = RangeType::new(ctx).map_pointer_value(value.into_pointer_value(), None);
|
||||
|
||||
let (start, stop, step) = destructure_range(ctx, val);
|
||||
|
||||
@ -1532,7 +1545,7 @@ pub fn call_rtio_log_impl<'ctx>(
|
||||
/// Generates a call to `core_log`.
|
||||
pub fn gen_core_log<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
obj: Option<&(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
@ -1549,7 +1562,7 @@ pub fn gen_core_log<'ctx>(
|
||||
/// Generates a call to `rtio_log`.
|
||||
pub fn gen_rtio_log<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: &Option<(Type, ValueEnum<'ctx>)>,
|
||||
obj: Option<&(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
|
@ -1,10 +1,4 @@
|
||||
#![deny(
|
||||
future_incompatible,
|
||||
let_underscore,
|
||||
nonstandard_style,
|
||||
rust_2024_compatibility,
|
||||
clippy::all
|
||||
)]
|
||||
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(
|
||||
unsafe_op_in_unsafe_fn,
|
||||
@ -30,26 +24,26 @@ use parking_lot::{Mutex, RwLock};
|
||||
use pyo3::{
|
||||
create_exception, exceptions,
|
||||
prelude::*,
|
||||
types::{PyBytes, PyDict, PySet},
|
||||
types::{PyBytes, PyDict, PyNone, PySet},
|
||||
};
|
||||
use tempfile::{self, TempDir};
|
||||
|
||||
use nac3core::{
|
||||
codegen::{
|
||||
concrete_type::ConcreteTypeStore, gen_func_impl, irrt::load_irrt, CodeGenLLVMOptions,
|
||||
CodeGenTargetMachineOptions, CodeGenTask, WithCall, WorkerRegistry,
|
||||
CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator, WithCall, WorkerRegistry,
|
||||
},
|
||||
inkwell::{
|
||||
context::Context,
|
||||
memory_buffer::MemoryBuffer,
|
||||
module::{Linkage, Module},
|
||||
module::{FlagBehavior, Linkage, Module},
|
||||
passes::PassBuilderOptions,
|
||||
support::is_multithreaded,
|
||||
targets::*,
|
||||
OptimizationLevel,
|
||||
},
|
||||
nac3parser::{
|
||||
ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
|
||||
ast::{self, Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
|
||||
parser::parse_program,
|
||||
},
|
||||
symbol_resolver::SymbolResolver,
|
||||
@ -84,14 +78,62 @@ enum Isa {
|
||||
}
|
||||
|
||||
impl Isa {
|
||||
/// Returns the number of bits in `size_t` for the [`Isa`].
|
||||
fn get_size_type(self) -> u32 {
|
||||
if self == Isa::Host {
|
||||
64u32
|
||||
} else {
|
||||
32u32
|
||||
/// Returns the [`TargetTriple`] used for compiling to this ISA.
|
||||
pub fn get_llvm_target_triple(self) -> TargetTriple {
|
||||
match self {
|
||||
Isa::Host => TargetMachine::get_default_triple(),
|
||||
Isa::RiscV32G | Isa::RiscV32IMA => TargetTriple::create("riscv32-unknown-linux"),
|
||||
Isa::CortexA9 => TargetTriple::create("armv7-unknown-linux-gnueabihf"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`String`] representing the target CPU used for compiling to this ISA.
|
||||
pub fn get_llvm_target_cpu(self) -> String {
|
||||
match self {
|
||||
Isa::Host => TargetMachine::get_host_cpu_name().to_string(),
|
||||
Isa::RiscV32G | Isa::RiscV32IMA => "generic-rv32".to_string(),
|
||||
Isa::CortexA9 => "cortex-a9".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`String`] representing the target features used for compiling to this ISA.
|
||||
pub fn get_llvm_target_features(self) -> String {
|
||||
match self {
|
||||
Isa::Host => TargetMachine::get_host_cpu_features().to_string(),
|
||||
Isa::RiscV32G => "+a,+m,+f,+d".to_string(),
|
||||
Isa::RiscV32IMA => "+a,+m".to_string(),
|
||||
Isa::CortexA9 => "+dsp,+fp16,+neon,+vfp3,+long-calls".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine
|
||||
/// options used for compiling to this ISA.
|
||||
pub fn get_llvm_target_options(self) -> CodeGenTargetMachineOptions {
|
||||
CodeGenTargetMachineOptions {
|
||||
triple: self.get_llvm_target_triple().as_str().to_string_lossy().into_owned(),
|
||||
cpu: self.get_llvm_target_cpu(),
|
||||
features: self.get_llvm_target_features(),
|
||||
reloc_mode: RelocMode::PIC,
|
||||
..CodeGenTargetMachineOptions::from_host()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an instance of [`TargetMachine`] used in compiling and linking of a program of this
|
||||
/// ISA.
|
||||
pub fn create_llvm_target_machine(self, opt_level: OptimizationLevel) -> TargetMachine {
|
||||
self.get_llvm_target_options()
|
||||
.create_target_machine(opt_level)
|
||||
.expect("couldn't create target machine")
|
||||
}
|
||||
|
||||
/// Returns the number of bits in `size_t` for this ISA.
|
||||
fn get_size_type(self, ctx: &Context) -> u32 {
|
||||
ctx.ptr_sized_int_type(
|
||||
&self.create_llvm_target_machine(OptimizationLevel::Default).get_target_data(),
|
||||
None,
|
||||
)
|
||||
.get_bit_width()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -117,6 +159,7 @@ pub struct PrimitivePythonId {
|
||||
generic_alias: (u64, u64),
|
||||
virtual_id: u64,
|
||||
option: u64,
|
||||
module: u64,
|
||||
}
|
||||
|
||||
type TopLevelComponent = (Stmt, String, PyObject);
|
||||
@ -148,14 +191,32 @@ impl Nac3 {
|
||||
module: &PyObject,
|
||||
registered_class_ids: &HashSet<u64>,
|
||||
) -> PyResult<()> {
|
||||
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> {
|
||||
let module: &PyAny = module.extract(py)?;
|
||||
Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?))
|
||||
})?;
|
||||
let (module_name, source_file, source) =
|
||||
Python::with_gil(|py| -> PyResult<(String, String, String)> {
|
||||
let module: &PyAny = module.extract(py)?;
|
||||
let source_file = module.getattr("__file__");
|
||||
let (source_file, source) = if let Ok(source_file) = source_file {
|
||||
let source_file = source_file.extract()?;
|
||||
(
|
||||
source_file,
|
||||
fs::read_to_string(source_file).map_err(|e| {
|
||||
exceptions::PyIOError::new_err(format!(
|
||||
"failed to read input file: {e}"
|
||||
))
|
||||
})?,
|
||||
)
|
||||
} else {
|
||||
// kernels submitted by content have no file
|
||||
// but still can provide source by StringLoader
|
||||
let get_src_fn = module
|
||||
.getattr("__loader__")?
|
||||
.extract::<PyObject>()?
|
||||
.getattr(py, "get_source")?;
|
||||
("<expcontent>", get_src_fn.call1(py, (PyNone::get(py),))?.extract(py)?)
|
||||
};
|
||||
Ok((module.getattr("__name__")?.extract()?, source_file.to_string(), source))
|
||||
})?;
|
||||
|
||||
let source = fs::read_to_string(&source_file).map_err(|e| {
|
||||
exceptions::PyIOError::new_err(format!("failed to read input file: {e}"))
|
||||
})?;
|
||||
let parser_result = parse_program(&source, source_file.into())
|
||||
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?;
|
||||
|
||||
@ -216,6 +277,10 @@ impl Nac3 {
|
||||
}
|
||||
})
|
||||
}
|
||||
// Allow global variable declaration with `Kernel` type annotation
|
||||
StmtKind::AnnAssign { ref annotation, .. } => {
|
||||
matches!(&annotation.node, ExprKind::Subscript { value, .. } if matches!(&value.node, ExprKind::Name {id, ..} if id == &"Kernel".into()))
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
|
||||
@ -318,7 +383,7 @@ impl Nac3 {
|
||||
vars: into_var_map([arg_ty]),
|
||||
},
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
|
||||
gen_core_log(ctx, &obj, fun, &args, generator)?;
|
||||
gen_core_log(ctx, obj.as_ref(), fun, &args, generator)?;
|
||||
|
||||
Ok(None)
|
||||
}))),
|
||||
@ -348,7 +413,7 @@ impl Nac3 {
|
||||
vars: into_var_map([arg_ty]),
|
||||
},
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
|
||||
gen_rtio_log(ctx, &obj, fun, &args, generator)?;
|
||||
gen_rtio_log(ctx, obj.as_ref(), fun, &args, generator)?;
|
||||
|
||||
Ok(None)
|
||||
}))),
|
||||
@ -366,7 +431,7 @@ impl Nac3 {
|
||||
py: Python,
|
||||
link_fn: &dyn Fn(&Module) -> PyResult<T>,
|
||||
) -> PyResult<T> {
|
||||
let size_t = self.isa.get_size_type();
|
||||
let size_t = self.isa.get_size_type(&Context::create());
|
||||
let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
|
||||
self.builtins.clone(),
|
||||
Self::get_lateinit_builtins(),
|
||||
@ -409,12 +474,14 @@ impl Nac3 {
|
||||
];
|
||||
add_exceptions(&mut composer, &mut builtins_def, &mut builtins_ty, &exception_names);
|
||||
|
||||
// Stores a mapping from module id to attributes
|
||||
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
|
||||
|
||||
let mut rpc_ids = vec![];
|
||||
for (stmt, path, module) in &self.top_levels {
|
||||
let py_module: &PyAny = module.extract(py)?;
|
||||
let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
|
||||
let module_name: String = py_module.getattr("__name__")?.extract()?;
|
||||
let helper = helper.clone();
|
||||
let class_obj;
|
||||
if let StmtKind::ClassDef { name, .. } = &stmt.node {
|
||||
@ -429,7 +496,7 @@ impl Nac3 {
|
||||
} else {
|
||||
class_obj = None;
|
||||
}
|
||||
let (name_to_pyid, resolver) =
|
||||
let (name_to_pyid, resolver, _, _) =
|
||||
module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| {
|
||||
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
|
||||
let members: &PyDict =
|
||||
@ -458,9 +525,17 @@ impl Nac3 {
|
||||
})))
|
||||
as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
let name_to_pyid = Rc::new(name_to_pyid);
|
||||
module_to_resolver_cache
|
||||
.insert(module_id, (name_to_pyid.clone(), resolver.clone()));
|
||||
(name_to_pyid, resolver)
|
||||
let module_location = ast::Location::new(1, 1, stmt.location.file);
|
||||
module_to_resolver_cache.insert(
|
||||
module_id,
|
||||
(
|
||||
name_to_pyid.clone(),
|
||||
resolver.clone(),
|
||||
module_name.clone(),
|
||||
Some(module_location),
|
||||
),
|
||||
);
|
||||
(name_to_pyid, resolver, module_name, Some(module_location))
|
||||
});
|
||||
|
||||
let (name, def_id, ty) = composer
|
||||
@ -534,6 +609,24 @@ impl Nac3 {
|
||||
}
|
||||
}
|
||||
|
||||
// Adding top level module definitions
|
||||
for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in
|
||||
module_to_resolver_cache
|
||||
{
|
||||
let def_id = composer
|
||||
.register_top_level_module(
|
||||
&module_name,
|
||||
&module_name_to_pyid,
|
||||
module_resolver,
|
||||
module_location,
|
||||
)
|
||||
.map_err(|e| {
|
||||
CompileError::new_err(format!("compilation failed\n----------\n{e}"))
|
||||
})?;
|
||||
|
||||
self.pyid_to_def.write().insert(module_id, def_id);
|
||||
}
|
||||
|
||||
let id_fun = PyModule::import(py, "builtins")?.getattr("id")?;
|
||||
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
|
||||
let module = PyModule::new(py, "tmp")?;
|
||||
@ -565,7 +658,7 @@ impl Nac3 {
|
||||
field_to_val: RwLock::default(),
|
||||
name_to_pyid,
|
||||
module: module.to_object(py),
|
||||
helper,
|
||||
helper: helper.clone(),
|
||||
string_store: self.string_store.clone(),
|
||||
exception_ids: self.exception_ids.clone(),
|
||||
deferred_eval_store: self.deferred_eval_store.clone(),
|
||||
@ -653,6 +746,9 @@ impl Nac3 {
|
||||
"Unsupported @rpc annotation on global variable",
|
||||
)))
|
||||
}
|
||||
TopLevelDef::Module { .. } => {
|
||||
unreachable!("Type module cannot be decorated with @rpc")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -673,33 +769,12 @@ impl Nac3 {
|
||||
let task = CodeGenTask {
|
||||
subst: Vec::default(),
|
||||
symbol_name: "__modinit__".to_string(),
|
||||
body: instance.body,
|
||||
signature,
|
||||
resolver: resolver.clone(),
|
||||
store,
|
||||
unifier_index: instance.unifier_id,
|
||||
calls: instance.calls,
|
||||
id: 0,
|
||||
};
|
||||
|
||||
let mut store = ConcreteTypeStore::new();
|
||||
let mut cache = HashMap::new();
|
||||
let signature = store.from_signature(
|
||||
&mut composer.unifier,
|
||||
&self.primitive,
|
||||
&fun_signature,
|
||||
&mut cache,
|
||||
);
|
||||
let signature = store.add_cty(signature);
|
||||
let attributes_writeback_task = CodeGenTask {
|
||||
subst: Vec::default(),
|
||||
symbol_name: "attributes_writeback".to_string(),
|
||||
body: Arc::new(Vec::default()),
|
||||
signature,
|
||||
resolver,
|
||||
store,
|
||||
unifier_index: instance.unifier_id,
|
||||
calls: Arc::new(HashMap::default()),
|
||||
calls: instance.calls,
|
||||
id: 0,
|
||||
};
|
||||
|
||||
@ -712,30 +787,47 @@ impl Nac3 {
|
||||
let buffer = buffer.as_slice().into();
|
||||
membuffer.lock().push(buffer);
|
||||
})));
|
||||
let size_t = context
|
||||
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
|
||||
.get_bit_width();
|
||||
let num_threads = if is_multithreaded() { 4 } else { 1 };
|
||||
let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect();
|
||||
let threads: Vec<_> = thread_names
|
||||
.iter()
|
||||
.map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns)))
|
||||
.map(|s| {
|
||||
Box::new(ArtiqCodeGenerator::with_target_machine(
|
||||
s.to_string(),
|
||||
&context,
|
||||
&self.get_llvm_target_machine(),
|
||||
self.time_fns,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let membuffer = membuffers.clone();
|
||||
let mut has_return = false;
|
||||
py.allow_threads(|| {
|
||||
let (registry, handles) =
|
||||
WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
|
||||
registry.add_task(task);
|
||||
registry.wait_tasks_complete(handles);
|
||||
|
||||
let mut generator =
|
||||
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
|
||||
let context = Context::create();
|
||||
let module = context.create_module("attributes_writeback");
|
||||
let mut generator = ArtiqCodeGenerator::with_target_machine(
|
||||
"main".to_string(),
|
||||
&context,
|
||||
&self.get_llvm_target_machine(),
|
||||
self.time_fns,
|
||||
);
|
||||
let module = context.create_module("main");
|
||||
let target_machine = self.llvm_options.create_target_machine().unwrap();
|
||||
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
|
||||
module.set_triple(&target_machine.get_triple());
|
||||
module.add_basic_value_flag(
|
||||
"Debug Info Version",
|
||||
FlagBehavior::Warning,
|
||||
context.i32_type().const_int(3, false),
|
||||
);
|
||||
module.add_basic_value_flag(
|
||||
"Dwarf Version",
|
||||
FlagBehavior::Warning,
|
||||
context.i32_type().const_int(4, false),
|
||||
);
|
||||
let builder = context.create_builder();
|
||||
let (_, module, _) = gen_func_impl(
|
||||
&context,
|
||||
@ -743,9 +835,27 @@ impl Nac3 {
|
||||
®istry,
|
||||
builder,
|
||||
module,
|
||||
attributes_writeback_task,
|
||||
task,
|
||||
|generator, ctx| {
|
||||
attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes)
|
||||
assert_eq!(instance.body.len(), 1, "toplevel module should have 1 statement");
|
||||
let StmtKind::Expr { value: ref expr, .. } = instance.body[0].node else {
|
||||
unreachable!("toplevel statement must be an expression")
|
||||
};
|
||||
let ExprKind::Call { .. } = expr.node else {
|
||||
unreachable!("toplevel expression must be a function call")
|
||||
};
|
||||
|
||||
let return_obj =
|
||||
generator.gen_expr(ctx, expr)?.map(|value| (expr.custom.unwrap(), value));
|
||||
has_return = return_obj.is_some();
|
||||
registry.wait_tasks_complete(handles);
|
||||
attributes_writeback(
|
||||
ctx,
|
||||
generator,
|
||||
inner_resolver.as_ref(),
|
||||
&host_attributes,
|
||||
return_obj,
|
||||
)
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
@ -754,35 +864,23 @@ impl Nac3 {
|
||||
membuffer.lock().push(buffer);
|
||||
});
|
||||
|
||||
embedding_map.setattr("expects_return", has_return).unwrap();
|
||||
|
||||
// Link all modules into `main`.
|
||||
let buffers = membuffers.lock();
|
||||
let main = context
|
||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
|
||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(
|
||||
buffers.last().unwrap(),
|
||||
"main",
|
||||
))
|
||||
.unwrap();
|
||||
for buffer in buffers.iter().skip(1) {
|
||||
for buffer in buffers.iter().rev().skip(1) {
|
||||
let other = context
|
||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
|
||||
.unwrap();
|
||||
|
||||
main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||
}
|
||||
let builder = context.create_builder();
|
||||
let modinit_return = main
|
||||
.get_function("__modinit__")
|
||||
.unwrap()
|
||||
.get_last_basic_block()
|
||||
.unwrap()
|
||||
.get_terminator()
|
||||
.unwrap();
|
||||
builder.position_before(&modinit_return);
|
||||
builder
|
||||
.build_call(
|
||||
main.get_function("attributes_writeback").unwrap(),
|
||||
&[],
|
||||
"attributes_writeback",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||
|
||||
let mut function_iter = main.get_first_function();
|
||||
@ -817,55 +915,27 @@ impl Nac3 {
|
||||
panic!("Failed to run optimization for module `main`: {}", err.to_string());
|
||||
}
|
||||
|
||||
Python::with_gil(|py| {
|
||||
let string_store = self.string_store.read();
|
||||
let mut string_store_vec = string_store.iter().collect::<Vec<_>>();
|
||||
string_store_vec.sort_by(|(_s1, key1), (_s2, key2)| key1.cmp(key2));
|
||||
for (s, key) in string_store_vec {
|
||||
let embed_key: i32 = helper.store_str.call1(py, (s,)).unwrap().extract(py).unwrap();
|
||||
assert_eq!(
|
||||
embed_key, *key,
|
||||
"string {s} is out of sync between embedding map (key={embed_key}) and \
|
||||
the internal string store (key={key})"
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
link_fn(&main)
|
||||
}
|
||||
|
||||
/// Returns the [`TargetTriple`] used for compiling to [isa].
|
||||
fn get_llvm_target_triple(isa: Isa) -> TargetTriple {
|
||||
match isa {
|
||||
Isa::Host => TargetMachine::get_default_triple(),
|
||||
Isa::RiscV32G | Isa::RiscV32IMA => TargetTriple::create("riscv32-unknown-linux"),
|
||||
Isa::CortexA9 => TargetTriple::create("armv7-unknown-linux-gnueabihf"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`String`] representing the target CPU used for compiling to [isa].
|
||||
fn get_llvm_target_cpu(isa: Isa) -> String {
|
||||
match isa {
|
||||
Isa::Host => TargetMachine::get_host_cpu_name().to_string(),
|
||||
Isa::RiscV32G | Isa::RiscV32IMA => "generic-rv32".to_string(),
|
||||
Isa::CortexA9 => "cortex-a9".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the [`String`] representing the target features used for compiling to [isa].
|
||||
fn get_llvm_target_features(isa: Isa) -> String {
|
||||
match isa {
|
||||
Isa::Host => TargetMachine::get_host_cpu_features().to_string(),
|
||||
Isa::RiscV32G => "+a,+m,+f,+d".to_string(),
|
||||
Isa::RiscV32IMA => "+a,+m".to_string(),
|
||||
Isa::CortexA9 => "+dsp,+fp16,+neon,+vfp3,+long-calls".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an instance of [`CodeGenTargetMachineOptions`] representing the target machine
|
||||
/// options used for compiling to [isa].
|
||||
fn get_llvm_target_options(isa: Isa) -> CodeGenTargetMachineOptions {
|
||||
CodeGenTargetMachineOptions {
|
||||
triple: Nac3::get_llvm_target_triple(isa).as_str().to_string_lossy().into_owned(),
|
||||
cpu: Nac3::get_llvm_target_cpu(isa),
|
||||
features: Nac3::get_llvm_target_features(isa),
|
||||
reloc_mode: RelocMode::PIC,
|
||||
..CodeGenTargetMachineOptions::from_host()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an instance of [`TargetMachine`] used in compiling and linking of a program to the
|
||||
/// target [isa].
|
||||
/// target [ISA][isa].
|
||||
fn get_llvm_target_machine(&self) -> TargetMachine {
|
||||
Nac3::get_llvm_target_options(self.isa)
|
||||
.create_target_machine(self.llvm_options.opt_level)
|
||||
.expect("couldn't create target machine")
|
||||
self.isa.create_llvm_target_machine(self.llvm_options.opt_level)
|
||||
}
|
||||
}
|
||||
|
||||
@ -973,7 +1043,8 @@ impl Nac3 {
|
||||
Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
|
||||
Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS,
|
||||
};
|
||||
let (primitive, _) = TopLevelComposer::make_primitives(isa.get_size_type());
|
||||
let (primitive, _) =
|
||||
TopLevelComposer::make_primitives(isa.get_size_type(&Context::create()));
|
||||
let builtins = vec![
|
||||
(
|
||||
"now_mu".into(),
|
||||
@ -1061,11 +1132,54 @@ impl Nac3 {
|
||||
tuple: get_attr_id(builtins_mod, "tuple"),
|
||||
exception: get_attr_id(builtins_mod, "Exception"),
|
||||
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
|
||||
module: get_attr_id(types_mod, "ModuleType"),
|
||||
};
|
||||
|
||||
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
|
||||
fs::write(working_directory.path().join("kernel.ld"), include_bytes!("kernel.ld")).unwrap();
|
||||
|
||||
let mut string_store: HashMap<String, i32> = HashMap::default();
|
||||
|
||||
// Keep this list of exceptions in sync with `EXCEPTION_ID_LOOKUP` in `artiq::firmware::ksupport::eh_artiq`
|
||||
// The exceptions declared here must be defined in `artiq.coredevice.exceptions`
|
||||
// Verify synchronization by running the test cases in `artiq.test.coredevice.test_exceptions`
|
||||
let runtime_exception_names = [
|
||||
"RTIOUnderflow",
|
||||
"RTIOOverflow",
|
||||
"RTIODestinationUnreachable",
|
||||
"DMAError",
|
||||
"I2CError",
|
||||
"CacheError",
|
||||
"SPIError",
|
||||
"SubkernelError",
|
||||
"0:AssertionError",
|
||||
"0:AttributeError",
|
||||
"0:IndexError",
|
||||
"0:IOError",
|
||||
"0:KeyError",
|
||||
"0:NotImplementedError",
|
||||
"0:OverflowError",
|
||||
"0:RuntimeError",
|
||||
"0:TimeoutError",
|
||||
"0:TypeError",
|
||||
"0:ValueError",
|
||||
"0:ZeroDivisionError",
|
||||
"0:LinAlgError",
|
||||
"UnwrapNoneError",
|
||||
];
|
||||
|
||||
// Preallocate runtime exception names
|
||||
for (i, name) in runtime_exception_names.iter().enumerate() {
|
||||
let exn_name = if name.find(':').is_none() {
|
||||
format!("0:artiq.coredevice.exceptions.{name}")
|
||||
} else {
|
||||
(*name).to_string()
|
||||
};
|
||||
|
||||
let id = i32::try_from(i).unwrap();
|
||||
string_store.insert(exn_name, id);
|
||||
}
|
||||
|
||||
Ok(Nac3 {
|
||||
isa,
|
||||
time_fns,
|
||||
@ -1075,17 +1189,22 @@ impl Nac3 {
|
||||
top_levels: Vec::default(),
|
||||
pyid_to_def: Arc::default(),
|
||||
working_directory,
|
||||
string_store: Arc::default(),
|
||||
string_store: Arc::new(string_store.into()),
|
||||
exception_ids: Arc::default(),
|
||||
deferred_eval_store: DeferredEvaluationStore::new(),
|
||||
llvm_options: CodeGenLLVMOptions {
|
||||
opt_level: OptimizationLevel::Default,
|
||||
target: Nac3::get_llvm_target_options(isa),
|
||||
target: isa.get_llvm_target_options(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
|
||||
fn analyze(
|
||||
&mut self,
|
||||
functions: &PySet,
|
||||
classes: &PySet,
|
||||
content_modules: &PySet,
|
||||
) -> PyResult<()> {
|
||||
let (modules, class_ids) =
|
||||
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
|
||||
let mut modules: HashMap<u64, PyObject> = HashMap::new();
|
||||
@ -1095,14 +1214,22 @@ impl Nac3 {
|
||||
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
|
||||
|
||||
for function in functions {
|
||||
let module = getmodule_fn.call1((function,))?.extract()?;
|
||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||
let module: PyObject = getmodule_fn.call1((function,))?.extract()?;
|
||||
if !module.is_none(py) {
|
||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||
}
|
||||
}
|
||||
for class in classes {
|
||||
let module = getmodule_fn.call1((class,))?.extract()?;
|
||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||
let module: PyObject = getmodule_fn.call1((class,))?.extract()?;
|
||||
if !module.is_none(py) {
|
||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||
}
|
||||
class_ids.insert(id_fn.call1((class,))?.extract()?);
|
||||
}
|
||||
for module in content_modules {
|
||||
let module: PyObject = module.extract()?;
|
||||
modules.insert(id_fn.call1((&module,))?.extract()?, module);
|
||||
}
|
||||
Ok((modules, class_ids))
|
||||
})?;
|
||||
|
||||
|
@ -10,18 +10,20 @@ use itertools::Itertools;
|
||||
use parking_lot::RwLock;
|
||||
use pyo3::{
|
||||
types::{PyDict, PyTuple},
|
||||
PyAny, PyObject, PyResult, Python,
|
||||
PyAny, PyErr, PyObject, PyResult, Python,
|
||||
};
|
||||
|
||||
use super::PrimitivePythonId;
|
||||
use nac3core::{
|
||||
codegen::{
|
||||
classes::{NDArrayType, ProxyType},
|
||||
types::{ndarray::NDArrayType, structure::StructProxyType, ProxyType},
|
||||
values::ndarray::make_contiguous_strides,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
inkwell::{
|
||||
module::Linkage,
|
||||
types::{BasicType, BasicTypeEnum},
|
||||
values::BasicValueEnum,
|
||||
values::{BasicValue, BasicValueEnum},
|
||||
AddressSpace,
|
||||
},
|
||||
nac3parser::ast::{self, StrRef},
|
||||
@ -37,8 +39,6 @@ use nac3core::{
|
||||
},
|
||||
};
|
||||
|
||||
use super::PrimitivePythonId;
|
||||
|
||||
pub enum PrimitiveValue {
|
||||
I32(i32),
|
||||
I64(i64),
|
||||
@ -674,6 +674,48 @@ impl InnerResolver {
|
||||
})
|
||||
});
|
||||
|
||||
// check if obj is module
|
||||
if self.helper.id_fn.call1(py, (ty.clone(),))?.extract::<u64>(py)?
|
||||
== self.primitive_ids.module
|
||||
&& self.pyid_to_def.read().contains_key(&py_obj_id)
|
||||
{
|
||||
let def_id = self.pyid_to_def.read()[&py_obj_id];
|
||||
let def = defs[def_id.0].read();
|
||||
let TopLevelDef::Module { name: module_name, module_id, attributes, methods, .. } =
|
||||
&*def
|
||||
else {
|
||||
unreachable!("must be a module here");
|
||||
};
|
||||
// Construct the module return type
|
||||
let mut module_attributes = HashMap::new();
|
||||
for (name, _) in attributes {
|
||||
let attribute_obj = obj.getattr(name.to_string().as_str())?;
|
||||
let attribute_ty =
|
||||
self.get_obj_type(py, attribute_obj, unifier, defs, primitives)?;
|
||||
if let Ok(attribute_ty) = attribute_ty {
|
||||
module_attributes.insert(*name, (attribute_ty, false));
|
||||
} else {
|
||||
return Ok(Err(format!("Unable to resolve {module_name}.{name}")));
|
||||
}
|
||||
}
|
||||
|
||||
for name in methods.keys() {
|
||||
let method_obj = obj.getattr(name.to_string().as_str())?;
|
||||
let method_ty = self.get_obj_type(py, method_obj, unifier, defs, primitives)?;
|
||||
if let Ok(method_ty) = method_ty {
|
||||
module_attributes.insert(*name, (method_ty, true));
|
||||
} else {
|
||||
return Ok(Err(format!("Unable to resolve {module_name}.{name}")));
|
||||
}
|
||||
}
|
||||
|
||||
let module_ty =
|
||||
TypeEnum::TModule { module_id: *module_id, attributes: module_attributes };
|
||||
|
||||
let ty = unifier.add_ty(module_ty);
|
||||
return Ok(Ok(ty));
|
||||
}
|
||||
|
||||
if let Some(ty) = constructor_ty {
|
||||
self.pyid_to_type.write().insert(py_obj_id, ty);
|
||||
return Ok(Ok(ty));
|
||||
@ -931,10 +973,13 @@ impl InnerResolver {
|
||||
|_| Ok(Ok(extracted_ty)),
|
||||
)
|
||||
} else if unifier.unioned(extracted_ty, primitives.bool) {
|
||||
obj.extract::<bool>().map_or_else(
|
||||
|_| Ok(Err(format!("{obj} is not in the range of bool"))),
|
||||
|_| Ok(Ok(extracted_ty)),
|
||||
)
|
||||
if obj.extract::<bool>().is_ok()
|
||||
|| obj.call_method("__bool__", (), None)?.extract::<bool>().is_ok()
|
||||
{
|
||||
Ok(Ok(extracted_ty))
|
||||
} else {
|
||||
Ok(Err(format!("{obj} is not in the range of bool")))
|
||||
}
|
||||
} else if unifier.unioned(extracted_ty, primitives.float) {
|
||||
obj.extract::<f64>().map_or_else(
|
||||
|_| Ok(Err(format!("{obj} is not in the range of float64"))),
|
||||
@ -974,10 +1019,14 @@ impl InnerResolver {
|
||||
let val: u64 = obj.extract().unwrap();
|
||||
self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val));
|
||||
Ok(Some(ctx.ctx.i64_type().const_int(val, false).into()))
|
||||
} else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ {
|
||||
} else if ty_id == self.primitive_ids.bool {
|
||||
let val: bool = obj.extract().unwrap();
|
||||
self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val));
|
||||
Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into()))
|
||||
} else if ty_id == self.primitive_ids.np_bool_ {
|
||||
let val: bool = obj.call_method("__bool__", (), None)?.extract().unwrap();
|
||||
self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val));
|
||||
Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into()))
|
||||
} else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ {
|
||||
let val: String = obj.extract().unwrap();
|
||||
self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone()));
|
||||
@ -1000,7 +1049,7 @@ impl InnerResolver {
|
||||
}
|
||||
_ => unreachable!("must be list"),
|
||||
};
|
||||
let size_t = generator.get_size_type(ctx.ctx);
|
||||
let size_t = ctx.get_size_type();
|
||||
let ty = if len == 0
|
||||
&& matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. })
|
||||
{
|
||||
@ -1085,18 +1134,19 @@ impl InnerResolver {
|
||||
} else {
|
||||
unreachable!("must be ndarray")
|
||||
};
|
||||
let (ndarray_dtype, ndarray_ndims) =
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
|
||||
let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty);
|
||||
let dtype = llvm_ndarray.element_type();
|
||||
|
||||
{
|
||||
if self.global_value_ids.read().contains_key(&id) {
|
||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
||||
ctx.module.add_global(
|
||||
ndarray_llvm_ty.as_underlying_type(),
|
||||
llvm_ndarray.as_abi_type().get_element_type().into_struct_type(),
|
||||
Some(AddressSpace::default()),
|
||||
&id_str,
|
||||
)
|
||||
@ -1106,40 +1156,44 @@ impl InnerResolver {
|
||||
self.global_value_ids.write().insert(id, obj.into());
|
||||
}
|
||||
|
||||
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
||||
else {
|
||||
unreachable!("Expected Literal for ndarray_ndims")
|
||||
};
|
||||
|
||||
let ndarray_ndims = if values.len() == 1 {
|
||||
values[0].clone()
|
||||
} else {
|
||||
todo!("Unpacking literal of more than one element unimplemented")
|
||||
};
|
||||
let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else {
|
||||
unreachable!("Expected u64 value for ndarray_ndims")
|
||||
};
|
||||
let ndims = llvm_ndarray.ndims();
|
||||
|
||||
// Obtain the shape of the ndarray
|
||||
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
|
||||
assert_eq!(shape_tuple.len(), ndarray_ndims as usize);
|
||||
let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
|
||||
assert_eq!(shape_tuple.len(), ndims as usize);
|
||||
|
||||
// The Rust type inferencer cannot figure this out
|
||||
let shape_values = shape_tuple
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, elem)| {
|
||||
self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err(
|
||||
|e| super::CompileError::new_err(format!("Error getting element {i}: {e}")),
|
||||
)
|
||||
let value = self
|
||||
.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize())
|
||||
.map_err(|e| {
|
||||
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
|
||||
})?
|
||||
.unwrap();
|
||||
let value = ctx
|
||||
.builder
|
||||
.build_int_z_extend(value.into_int_value(), llvm_usize, "")
|
||||
.unwrap();
|
||||
Ok(value)
|
||||
})
|
||||
.collect();
|
||||
let shape_values = shape_values?.unwrap();
|
||||
let shape_values = llvm_usize.const_array(
|
||||
&shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(),
|
||||
);
|
||||
.collect::<Result<Vec<_>, PyErr>>()?;
|
||||
|
||||
// Also use this opportunity to get the constant values of `shape_values` for calculating strides.
|
||||
let shape_u64s = shape_values
|
||||
.iter()
|
||||
.map(|dim| {
|
||||
assert!(dim.is_const());
|
||||
dim.get_zero_extended_constant().unwrap()
|
||||
})
|
||||
.collect_vec();
|
||||
let shape_values = llvm_usize.const_array(&shape_values);
|
||||
|
||||
// create a global for ndarray.shape and initialize it using the shape
|
||||
let shape_global = ctx.module.add_global(
|
||||
llvm_usize.array_type(ndarray_ndims as u32),
|
||||
llvm_usize.array_type(ndims as u32),
|
||||
Some(AddressSpace::default()),
|
||||
&(id_str.clone() + ".shape"),
|
||||
);
|
||||
@ -1147,17 +1201,25 @@ impl InnerResolver {
|
||||
|
||||
// Obtain the (flattened) elements of the ndarray
|
||||
let sz: usize = obj.getattr("size")?.extract()?;
|
||||
let data: Result<Option<Vec<_>>, _> = (0..sz)
|
||||
let data: Vec<_> = (0..sz)
|
||||
.map(|i| {
|
||||
obj.getattr("flat")?.get_item(i).and_then(|elem| {
|
||||
self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| {
|
||||
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
|
||||
})
|
||||
let value = self
|
||||
.get_obj_value(py, elem, ctx, generator, ndarray_dtype)
|
||||
.map_err(|e| {
|
||||
super::CompileError::new_err(format!(
|
||||
"Error getting element {i}: {e}"
|
||||
))
|
||||
})?
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(value.get_type(), dtype);
|
||||
Ok(value)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let data = data?.unwrap().into_iter();
|
||||
let data = match ndarray_dtype_llvm_ty {
|
||||
.try_collect()?;
|
||||
let data = data.into_iter();
|
||||
let data = match dtype {
|
||||
BasicTypeEnum::ArrayType(ty) => {
|
||||
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
|
||||
}
|
||||
@ -1182,34 +1244,93 @@ impl InnerResolver {
|
||||
};
|
||||
|
||||
// create a global for ndarray.data and initialize it using the elements
|
||||
//
|
||||
// NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`.
|
||||
// We will have to cast it to an `u8*` later.
|
||||
let data_global = ctx.module.add_global(
|
||||
ndarray_dtype_llvm_ty.array_type(sz as u32),
|
||||
dtype.array_type(sz as u32),
|
||||
Some(AddressSpace::default()),
|
||||
&(id_str.clone() + ".data"),
|
||||
);
|
||||
data_global.set_initializer(&data);
|
||||
|
||||
// Get the constant itemsize.
|
||||
//
|
||||
// NOTE: dtype.size_of() may return a non-constant, where `TargetData::get_store_size`
|
||||
// will always return a constant size.
|
||||
let itemsize = ctx
|
||||
.registry
|
||||
.llvm_options
|
||||
.create_target_machine()
|
||||
.map(|tm| tm.get_target_data().get_store_size(&dtype))
|
||||
.unwrap();
|
||||
assert_ne!(itemsize, 0);
|
||||
|
||||
// Create the strides needed for ndarray.strides
|
||||
let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s);
|
||||
let strides =
|
||||
strides.into_iter().map(|stride| llvm_usize.const_int(stride, false)).collect_vec();
|
||||
let strides = llvm_usize.const_array(&strides);
|
||||
|
||||
// create a global for ndarray.strides and initialize it
|
||||
let strides_global = ctx.module.add_global(
|
||||
llvm_usize.array_type(ndims as u32),
|
||||
Some(AddressSpace::default()),
|
||||
&format!("${id_str}.strides"),
|
||||
);
|
||||
strides_global.set_initializer(&strides);
|
||||
|
||||
// create a global for the ndarray object and initialize it
|
||||
let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[
|
||||
llvm_usize.const_int(ndarray_ndims, false).into(),
|
||||
shape_global
|
||||
.as_pointer_value()
|
||||
.const_cast(llvm_usize.ptr_type(AddressSpace::default()))
|
||||
.into(),
|
||||
data_global
|
||||
.as_pointer_value()
|
||||
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
|
||||
.into(),
|
||||
|
||||
// NOTE: data_global is an array of dtype, we want a `u8*`.
|
||||
let ndarray_data = data_global.as_pointer_value();
|
||||
let ndarray_data = ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
|
||||
|
||||
let ndarray_itemsize = llvm_usize.const_int(itemsize, false);
|
||||
|
||||
let ndarray_ndims = llvm_usize.const_int(ndims, false);
|
||||
|
||||
// calling as_pointer_value on shape and strides returns [i64 x ndims]*
|
||||
// convert into i64* to conform with expected layout of ndarray
|
||||
|
||||
let ndarray_shape = shape_global.as_pointer_value();
|
||||
let ndarray_shape = unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(
|
||||
ndarray_shape,
|
||||
&[llvm_usize.const_zero(), llvm_usize.const_zero()],
|
||||
"",
|
||||
)
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let ndarray_strides = strides_global.as_pointer_value();
|
||||
let ndarray_strides = unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(
|
||||
ndarray_strides,
|
||||
&[llvm_usize.const_zero(), llvm_usize.const_zero()],
|
||||
"",
|
||||
)
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let ndarray = llvm_ndarray.get_struct_type().const_named_struct(&[
|
||||
ndarray_itemsize.into(),
|
||||
ndarray_ndims.into(),
|
||||
ndarray_shape.into(),
|
||||
ndarray_strides.into(),
|
||||
ndarray_data.into(),
|
||||
]);
|
||||
|
||||
let ndarray = ctx.module.add_global(
|
||||
ndarray_llvm_ty.as_underlying_type(),
|
||||
let ndarray_global = ctx.module.add_global(
|
||||
llvm_ndarray.as_abi_type().get_element_type().into_struct_type(),
|
||||
Some(AddressSpace::default()),
|
||||
&id_str,
|
||||
);
|
||||
ndarray.set_initializer(&value);
|
||||
ndarray_global.set_initializer(&ndarray);
|
||||
|
||||
Ok(Some(ndarray.as_pointer_value().into()))
|
||||
Ok(Some(ndarray_global.as_pointer_value().into()))
|
||||
} else if ty_id == self.primitive_ids.tuple {
|
||||
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
|
||||
let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else {
|
||||
@ -1290,6 +1411,77 @@ impl InnerResolver {
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
} else if ty_id == self.primitive_ids.module {
|
||||
let id_str = id.to_string();
|
||||
|
||||
if let Some(global) = ctx.module.get_global(&id_str) {
|
||||
return Ok(Some(global.as_pointer_value().into()));
|
||||
}
|
||||
|
||||
let top_level_defs = ctx.top_level.definitions.read();
|
||||
let ty = self
|
||||
.get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)?
|
||||
.unwrap();
|
||||
let ty = ctx
|
||||
.get_llvm_type(generator, ty)
|
||||
.into_pointer_type()
|
||||
.get_element_type()
|
||||
.into_struct_type();
|
||||
|
||||
{
|
||||
if self.global_value_ids.read().contains_key(&id) {
|
||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
||||
ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str)
|
||||
});
|
||||
return Ok(Some(global.as_pointer_value().into()));
|
||||
}
|
||||
self.global_value_ids.write().insert(id, obj.into());
|
||||
}
|
||||
|
||||
let fields = {
|
||||
let definition =
|
||||
top_level_defs.get(self.pyid_to_def.read().get(&id).unwrap().0).unwrap().read();
|
||||
let TopLevelDef::Module { attributes, .. } = &*definition else { unreachable!() };
|
||||
attributes
|
||||
.iter()
|
||||
.filter_map(|f| {
|
||||
let definition = top_level_defs.get(f.1 .0).unwrap().read();
|
||||
if let TopLevelDef::Variable { ty, .. } = &*definition {
|
||||
Some((f.0, *ty))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect_vec()
|
||||
};
|
||||
|
||||
let values: Result<Option<Vec<_>>, _> = fields
|
||||
.iter()
|
||||
.map(|(name, ty)| {
|
||||
self.get_obj_value(
|
||||
py,
|
||||
obj.getattr(name.to_string().as_str())?,
|
||||
ctx,
|
||||
generator,
|
||||
*ty,
|
||||
)
|
||||
.map_err(|e| {
|
||||
super::CompileError::new_err(format!("Error getting field {name}: {e}"))
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let values = values?;
|
||||
|
||||
if let Some(values) = values {
|
||||
let val = ty.const_named_struct(&values);
|
||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
||||
ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str)
|
||||
});
|
||||
global.set_initializer(&val);
|
||||
Ok(Some(global.as_pointer_value().into()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
let id_str = id.to_string();
|
||||
|
||||
@ -1369,9 +1561,12 @@ impl InnerResolver {
|
||||
} else if ty_id == self.primitive_ids.uint64 {
|
||||
let val: u64 = obj.extract()?;
|
||||
Ok(SymbolValue::U64(val))
|
||||
} else if ty_id == self.primitive_ids.bool || ty_id == self.primitive_ids.np_bool_ {
|
||||
} else if ty_id == self.primitive_ids.bool {
|
||||
let val: bool = obj.extract()?;
|
||||
Ok(SymbolValue::Bool(val))
|
||||
} else if ty_id == self.primitive_ids.np_bool_ {
|
||||
let val: bool = obj.call_method("__bool__", (), None)?.extract()?;
|
||||
Ok(SymbolValue::Bool(val))
|
||||
} else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ {
|
||||
let val: String = obj.extract()?;
|
||||
Ok(SymbolValue::Str(val))
|
||||
@ -1469,9 +1664,50 @@ impl SymbolResolver for Resolver {
|
||||
fn get_symbol_value<'ctx>(
|
||||
&self,
|
||||
id: StrRef,
|
||||
_: &mut CodeGenContext<'ctx, '_>,
|
||||
_: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Option<ValueEnum<'ctx>> {
|
||||
if let Some(def_id) = self.0.id_to_def.read().get(&id) {
|
||||
let top_levels = ctx.top_level.definitions.read();
|
||||
if matches!(&*top_levels[def_id.0].read(), TopLevelDef::Variable { .. }) {
|
||||
let module_val = &self.0.module;
|
||||
let ret = Python::with_gil(|py| -> PyResult<Result<BasicValueEnum, String>> {
|
||||
let module_val = module_val.as_ref(py);
|
||||
|
||||
let ty = self.0.get_obj_type(
|
||||
py,
|
||||
module_val,
|
||||
&mut ctx.unifier,
|
||||
&top_levels,
|
||||
&ctx.primitives,
|
||||
)?;
|
||||
if let Err(ty) = ty {
|
||||
return Ok(Err(ty));
|
||||
}
|
||||
let ty = ty.unwrap();
|
||||
let obj = self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap();
|
||||
let (idx, _) = ctx.get_attr_index(ty, id);
|
||||
let ret = unsafe {
|
||||
ctx.builder.build_gep(
|
||||
obj.into_pointer_value(),
|
||||
&[
|
||||
ctx.ctx.i32_type().const_zero(),
|
||||
ctx.ctx.i32_type().const_int(idx as u64, false),
|
||||
],
|
||||
id.to_string().as_str(),
|
||||
)
|
||||
}
|
||||
.unwrap();
|
||||
Ok(Ok(ret.as_basic_value_enum()))
|
||||
})
|
||||
.unwrap();
|
||||
if ret.is_err() {
|
||||
return None;
|
||||
}
|
||||
return Some(ret.unwrap().into());
|
||||
}
|
||||
}
|
||||
|
||||
let sym_value = {
|
||||
let id_to_val = self.0.id_to_pyval.read();
|
||||
id_to_val.get(&id).cloned()
|
||||
@ -1532,10 +1768,7 @@ impl SymbolResolver for Resolver {
|
||||
if let Some(id) = string_store.get(s) {
|
||||
*id
|
||||
} else {
|
||||
let id = Python::with_gil(|py| -> PyResult<i32> {
|
||||
self.0.helper.store_str.call1(py, (s,))?.extract(py)
|
||||
})
|
||||
.unwrap();
|
||||
let id = i32::try_from(string_store.len()).unwrap();
|
||||
string_store.insert(s.into(), id);
|
||||
id
|
||||
}
|
||||
|
@ -10,7 +10,6 @@ constant-optimization = ["fold"]
|
||||
fold = []
|
||||
|
||||
[dependencies]
|
||||
lazy_static = "1.5"
|
||||
parking_lot = "0.12"
|
||||
string-interner = "0.17"
|
||||
string-interner = "0.18"
|
||||
fxhash = "0.2"
|
||||
|
@ -5,14 +5,12 @@ pub use crate::location::Location;
|
||||
|
||||
use fxhash::FxBuildHasher;
|
||||
use parking_lot::{Mutex, MutexGuard};
|
||||
use std::{cell::RefCell, collections::HashMap, fmt};
|
||||
use std::{cell::RefCell, collections::HashMap, fmt, sync::LazyLock};
|
||||
use string_interner::{symbol::SymbolU32, DefaultBackend, StringInterner};
|
||||
|
||||
pub type Interner = StringInterner<DefaultBackend, FxBuildHasher>;
|
||||
lazy_static! {
|
||||
static ref INTERNER: Mutex<Interner> =
|
||||
Mutex::new(StringInterner::with_hasher(FxBuildHasher::default()));
|
||||
}
|
||||
static INTERNER: LazyLock<Mutex<Interner>> =
|
||||
LazyLock::new(|| Mutex::new(StringInterner::with_hasher(FxBuildHasher::default())));
|
||||
|
||||
thread_local! {
|
||||
static LOCAL_INTERNER: RefCell<HashMap<String, StrRef>> = RefCell::default();
|
||||
|
@ -1,10 +1,4 @@
|
||||
#![deny(
|
||||
future_incompatible,
|
||||
let_underscore,
|
||||
nonstandard_style,
|
||||
rust_2024_compatibility,
|
||||
clippy::all
|
||||
)]
|
||||
#![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
|
||||
#![warn(clippy::pedantic)]
|
||||
#![allow(
|
||||
clippy::missing_errors_doc,
|
||||
@ -14,9 +8,6 @@
|
||||
clippy::wildcard_imports
|
||||
)]
|
||||
|
||||
#[macro_use]
|
||||
extern crate lazy_static;
|
||||
|
||||
mod ast_gen;
|
||||
mod constant;
|
||||
#[cfg(feature = "fold")]
|
||||
|
@ -5,14 +5,16 @@ authors = ["M-Labs"]
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["derive"]
|
||||
derive = ["dep:nac3core_derive"]
|
||||
no-escape-analysis = []
|
||||
|
||||
[dependencies]
|
||||
itertools = "0.13"
|
||||
itertools = "0.14"
|
||||
crossbeam = "0.8"
|
||||
indexmap = "2.6"
|
||||
indexmap = "2.7"
|
||||
parking_lot = "0.12"
|
||||
rayon = "1.10"
|
||||
nac3core_derive = { path = "nac3core_derive", optional = true }
|
||||
nac3parser = { path = "../nac3parser" }
|
||||
strum = "0.26"
|
||||
strum_macros = "0.26"
|
||||
@ -28,4 +30,4 @@ indoc = "2.0"
|
||||
insta = "=1.11.0"
|
||||
|
||||
[build-dependencies]
|
||||
regex = "1.10"
|
||||
regex = "1.11"
|
||||
|
@ -56,9 +56,8 @@ fn main() {
|
||||
let output = Command::new("clang-irrt")
|
||||
.args(flags)
|
||||
.output()
|
||||
.map(|o| {
|
||||
.inspect(|o| {
|
||||
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
|
||||
o
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
|
@ -1,6 +1,15 @@
|
||||
#include "irrt/exception.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/list.hpp"
|
||||
#include "irrt/math.hpp"
|
||||
#include "irrt/ndarray.hpp"
|
||||
#include "irrt/range.hpp"
|
||||
#include "irrt/slice.hpp"
|
||||
#include "irrt/string.hpp"
|
||||
#include "irrt/ndarray/basic.hpp"
|
||||
#include "irrt/ndarray/def.hpp"
|
||||
#include "irrt/ndarray/iter.hpp"
|
||||
#include "irrt/ndarray/indexing.hpp"
|
||||
#include "irrt/ndarray/array.hpp"
|
||||
#include "irrt/ndarray/reshape.hpp"
|
||||
#include "irrt/ndarray/broadcast.hpp"
|
||||
#include "irrt/ndarray/transpose.hpp"
|
||||
#include "irrt/ndarray/matmul.hpp"
|
@ -4,6 +4,6 @@
|
||||
|
||||
template<typename SizeT>
|
||||
struct CSlice {
|
||||
uint8_t* base;
|
||||
void* base;
|
||||
SizeT len;
|
||||
};
|
@ -6,7 +6,7 @@
|
||||
/**
|
||||
* @brief The int type of ARTIQ exception IDs.
|
||||
*/
|
||||
typedef int32_t ExceptionId;
|
||||
using ExceptionId = int32_t;
|
||||
|
||||
/*
|
||||
* Set of exceptions C++ IRRT can use.
|
||||
@ -55,11 +55,14 @@ void _raise_exception_helper(ExceptionId id,
|
||||
int64_t param2) {
|
||||
Exception<SizeT> e = {
|
||||
.id = id,
|
||||
.filename = {.base = reinterpret_cast<const uint8_t*>(filename), .len = __builtin_strlen(filename)},
|
||||
.filename = {.base = reinterpret_cast<void*>(const_cast<char*>(filename)),
|
||||
.len = static_cast<SizeT>(__builtin_strlen(filename))},
|
||||
.line = line,
|
||||
.column = 0,
|
||||
.function = {.base = reinterpret_cast<const uint8_t*>(function), .len = __builtin_strlen(function)},
|
||||
.msg = {.base = reinterpret_cast<const uint8_t*>(msg), .len = __builtin_strlen(msg)},
|
||||
.function = {.base = reinterpret_cast<void*>(const_cast<char*>(function)),
|
||||
.len = static_cast<SizeT>(__builtin_strlen(function))},
|
||||
.msg = {.base = reinterpret_cast<void*>(const_cast<char*>(msg)),
|
||||
.len = static_cast<SizeT>(__builtin_strlen(msg))},
|
||||
};
|
||||
e.params[0] = param0;
|
||||
e.params[1] = param1;
|
||||
@ -67,6 +70,7 @@ void _raise_exception_helper(ExceptionId id,
|
||||
__nac3_raise(reinterpret_cast<void*>(&e));
|
||||
__builtin_unreachable();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/**
|
||||
* @brief Raise an exception with location details (location in the IRRT source files).
|
||||
@ -79,4 +83,3 @@ void _raise_exception_helper(ExceptionId id,
|
||||
*/
|
||||
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
|
||||
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)
|
||||
} // namespace
|
@ -8,15 +8,18 @@ using uint32_t = unsigned _BitInt(32);
|
||||
using int64_t = _BitInt(64);
|
||||
using uint64_t = unsigned _BitInt(64);
|
||||
#else
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wdeprecated-type"
|
||||
using int8_t = _ExtInt(8);
|
||||
using uint8_t = unsigned _ExtInt(8);
|
||||
using int32_t = _ExtInt(32);
|
||||
using uint32_t = unsigned _ExtInt(32);
|
||||
using int64_t = _ExtInt(64);
|
||||
using uint64_t = unsigned _ExtInt(64);
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
#endif
|
||||
|
||||
// NDArray indices are always `uint32_t`.
|
||||
using NDIndex = uint32_t;
|
||||
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
|
||||
using SliceIndex = int32_t;
|
||||
|
@ -2,6 +2,21 @@
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/math_util.hpp"
|
||||
#include "irrt/slice.hpp"
|
||||
|
||||
namespace {
|
||||
/**
|
||||
* @brief A list in NAC3.
|
||||
*
|
||||
* The `items` field is opaque. You must rely on external contexts to
|
||||
* know how to interpret it.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
struct List {
|
||||
uint8_t* items;
|
||||
SizeT len;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
// Handle list assignment and dropping part of the list when
|
||||
@ -13,12 +28,12 @@ extern "C" {
|
||||
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
|
||||
SliceIndex dest_end,
|
||||
SliceIndex dest_step,
|
||||
uint8_t* dest_arr,
|
||||
void* dest_arr,
|
||||
SliceIndex dest_arr_len,
|
||||
SliceIndex src_start,
|
||||
SliceIndex src_end,
|
||||
SliceIndex src_step,
|
||||
uint8_t* src_arr,
|
||||
void* src_arr,
|
||||
SliceIndex src_arr_len,
|
||||
const SliceIndex size) {
|
||||
/* if dest_arr_len == 0, do nothing since we do not support extending list */
|
||||
@ -29,11 +44,13 @@ SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
|
||||
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
|
||||
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
|
||||
if (src_len > 0) {
|
||||
__builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
|
||||
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_start * size,
|
||||
static_cast<uint8_t*>(src_arr) + src_start * size, src_len * size);
|
||||
}
|
||||
if (dest_len > 0) {
|
||||
/* dropping */
|
||||
__builtin_memmove(dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size,
|
||||
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + (dest_start + src_len) * size,
|
||||
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
|
||||
(dest_arr_len - dest_end - 1) * size);
|
||||
}
|
||||
/* shrink size */
|
||||
@ -44,7 +61,7 @@ SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
|
||||
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|
||||
|| max(src_start, src_end) < min(dest_start, dest_end));
|
||||
if (need_alloca) {
|
||||
uint8_t* tmp = reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
|
||||
void* tmp = __builtin_alloca(src_arr_len * size);
|
||||
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
|
||||
src_arr = tmp;
|
||||
}
|
||||
@ -53,20 +70,24 @@ SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
|
||||
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
|
||||
/* for constant optimization */
|
||||
if (size == 1) {
|
||||
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
|
||||
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind, static_cast<uint8_t*>(src_arr) + src_ind, 1);
|
||||
} else if (size == 4) {
|
||||
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
|
||||
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 4,
|
||||
static_cast<uint8_t*>(src_arr) + src_ind * 4, 4);
|
||||
} else if (size == 8) {
|
||||
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
|
||||
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 8,
|
||||
static_cast<uint8_t*>(src_arr) + src_ind * 8, 8);
|
||||
} else {
|
||||
/* memcpy for var size, cannot overlap after previous alloca */
|
||||
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
|
||||
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
|
||||
static_cast<uint8_t*>(src_arr) + src_ind * size, size);
|
||||
}
|
||||
}
|
||||
/* only dest_step == 1 can we shrink the dest list. */
|
||||
/* size should be ensured prior to calling this function */
|
||||
if (dest_step == 1 && dest_end >= dest_start) {
|
||||
__builtin_memmove(dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
|
||||
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
|
||||
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
|
||||
(dest_arr_len - dest_end - 1) * size);
|
||||
return dest_arr_len - (dest_end - dest_ind) - 1;
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
namespace {
|
||||
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||
// need to make sure `exp >= 0` before calling this function
|
||||
@ -90,4 +92,4 @@ double __nac3_j0(double x) {
|
||||
|
||||
return j0(x);
|
||||
}
|
||||
}
|
||||
} // namespace
|
@ -1,144 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
namespace {
|
||||
template<typename SizeT>
|
||||
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
||||
__builtin_assume(end_idx <= list_len);
|
||||
|
||||
SizeT num_elems = 1;
|
||||
for (SizeT i = begin_idx; i < end_idx; ++i) {
|
||||
SizeT val = list_data[i];
|
||||
__builtin_assume(val > 0);
|
||||
num_elems *= val;
|
||||
}
|
||||
return num_elems;
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) {
|
||||
SizeT stride = 1;
|
||||
for (SizeT dim = 0; dim < num_dims; dim++) {
|
||||
SizeT i = num_dims - dim - 1;
|
||||
__builtin_assume(dims[i] > 0);
|
||||
idxs[i] = (index / stride) % dims[i];
|
||||
stride *= dims[i];
|
||||
}
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) {
|
||||
SizeT idx = 0;
|
||||
SizeT stride = 1;
|
||||
for (SizeT i = 0; i < num_dims; ++i) {
|
||||
SizeT ri = num_dims - i - 1;
|
||||
if (ri < num_indices) {
|
||||
idx += stride * indices[ri];
|
||||
}
|
||||
|
||||
__builtin_assume(dims[i] > 0);
|
||||
stride *= dims[ri];
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
|
||||
SizeT lhs_ndims,
|
||||
const SizeT* rhs_dims,
|
||||
SizeT rhs_ndims,
|
||||
SizeT* out_dims) {
|
||||
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
||||
|
||||
for (SizeT i = 0; i < max_ndims; ++i) {
|
||||
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
|
||||
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
|
||||
SizeT* out_dim = &out_dims[max_ndims - i - 1];
|
||||
|
||||
if (lhs_dim_sz == nullptr) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (rhs_dim_sz == nullptr) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == 1) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (*rhs_dim_sz == 1) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else {
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
|
||||
SizeT src_ndims,
|
||||
const NDIndex* in_idx,
|
||||
NDIndex* out_idx) {
|
||||
for (SizeT i = 0; i < src_ndims; ++i) {
|
||||
SizeT src_i = src_ndims - i - 1;
|
||||
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
|
||||
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||
}
|
||||
|
||||
uint64_t
|
||||
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
|
||||
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) {
|
||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) {
|
||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||
}
|
||||
|
||||
uint32_t
|
||||
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) {
|
||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||
}
|
||||
|
||||
uint64_t
|
||||
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) {
|
||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
|
||||
uint32_t lhs_ndims,
|
||||
const uint32_t* rhs_dims,
|
||||
uint32_t rhs_ndims,
|
||||
uint32_t* out_dims) {
|
||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
|
||||
uint64_t lhs_ndims,
|
||||
const uint64_t* rhs_dims,
|
||||
uint64_t rhs_ndims,
|
||||
uint64_t* out_dims) {
|
||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
|
||||
uint32_t src_ndims,
|
||||
const NDIndex* in_idx,
|
||||
NDIndex* out_idx) {
|
||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
|
||||
uint64_t src_ndims,
|
||||
const NDIndex* in_idx,
|
||||
NDIndex* out_idx) {
|
||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||
}
|
||||
}
|
132
nac3core/irrt/irrt/ndarray/array.hpp
Normal file
132
nac3core/irrt/irrt/ndarray/array.hpp
Normal file
@ -0,0 +1,132 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/debug.hpp"
|
||||
#include "irrt/exception.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/list.hpp"
|
||||
#include "irrt/ndarray/basic.hpp"
|
||||
#include "irrt/ndarray/def.hpp"
|
||||
|
||||
namespace {
|
||||
namespace ndarray::array {
|
||||
/**
|
||||
* @brief In the context of `np.array(<list>)`, deduce the ndarray's shape produced by `<list>` and raise
|
||||
* an exception if there is anything wrong with `<shape>` (e.g., inconsistent dimensions `np.array([[1.0, 2.0],
|
||||
* [3.0]])`)
|
||||
*
|
||||
* If this function finds no issues with `<list>`, the deduced shape is written to `shape`. The caller has the
|
||||
* responsibility to allocate `[SizeT; ndims]` for `shape`. The caller must also initialize `shape` with `-1`s because
|
||||
* of implementation details.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void set_and_validate_list_shape_helper(SizeT axis, List<SizeT>* list, SizeT ndims, SizeT* shape) {
|
||||
if (shape[axis] == -1) {
|
||||
// Dimension is unspecified. Set it.
|
||||
shape[axis] = list->len;
|
||||
} else {
|
||||
// Dimension is specified. Check.
|
||||
if (shape[axis] != list->len) {
|
||||
// Mismatch, throw an error.
|
||||
// NOTE: NumPy's error message is more complex and needs more PARAMS to display.
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||
"The requested array has an inhomogenous shape "
|
||||
"after {0} dimension(s).",
|
||||
axis, shape[axis], list->len);
|
||||
}
|
||||
}
|
||||
|
||||
if (axis + 1 == ndims) {
|
||||
// `list` has type `list[ItemType]`
|
||||
// Do nothing
|
||||
} else {
|
||||
// `list` has type `list[list[...]]`
|
||||
List<SizeT>** lists = (List<SizeT>**)(list->items);
|
||||
for (SizeT i = 0; i < list->len; i++) {
|
||||
set_and_validate_list_shape_helper<SizeT>(axis + 1, lists[i], ndims, shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief See `set_and_validate_list_shape_helper`.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void set_and_validate_list_shape(List<SizeT>* list, SizeT ndims, SizeT* shape) {
|
||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||
shape[axis] = -1; // Sentinel to say this dimension is unspecified.
|
||||
}
|
||||
set_and_validate_list_shape_helper<SizeT>(0, list, ndims, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief In the context of `np.array(<list>)`, copied the contents stored in `list` to `ndarray`.
|
||||
*
|
||||
* `list` is assumed to be "legal". (i.e., no inconsistent dimensions)
|
||||
*
|
||||
* # Notes on `ndarray`
|
||||
* The caller is responsible for allocating space for `ndarray`.
|
||||
* Here is what this function expects from `ndarray` when called:
|
||||
* - `ndarray->data` has to be allocated, contiguous, and may contain uninitialized values.
|
||||
* - `ndarray->itemsize` has to be initialized.
|
||||
* - `ndarray->ndims` has to be initialized.
|
||||
* - `ndarray->shape` has to be initialized.
|
||||
* - `ndarray->strides` is ignored, but note that `ndarray->data` is contiguous.
|
||||
* When this function call ends:
|
||||
* - `ndarray->data` is written with contents from `<list>`.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void write_list_to_array_helper(SizeT axis, SizeT* index, List<SizeT>* list, NDArray<SizeT>* ndarray) {
|
||||
debug_assert_eq(SizeT, list->len, ndarray->shape[axis]);
|
||||
if (IRRT_DEBUG_ASSERT_BOOL) {
|
||||
if (!ndarray::basic::is_c_contiguous(ndarray)) {
|
||||
raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0], ndarray->strides[1],
|
||||
NO_PARAM);
|
||||
}
|
||||
}
|
||||
|
||||
if (axis + 1 == ndarray->ndims) {
|
||||
// `list` has type `list[scalar]`
|
||||
// `ndarray` is contiguous, so we can do this, and this is fast.
|
||||
uint8_t* dst = static_cast<uint8_t*>(ndarray->data) + (ndarray->itemsize * (*index));
|
||||
__builtin_memcpy(dst, list->items, ndarray->itemsize * list->len);
|
||||
*index += list->len;
|
||||
} else {
|
||||
// `list` has type `list[list[...]]`
|
||||
List<SizeT>** lists = (List<SizeT>**)(list->items);
|
||||
|
||||
for (SizeT i = 0; i < list->len; i++) {
|
||||
write_list_to_array_helper<SizeT>(axis + 1, index, lists[i], ndarray);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief See `write_list_to_array_helper`.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void write_list_to_array(List<SizeT>* list, NDArray<SizeT>* ndarray) {
|
||||
SizeT index = 0;
|
||||
write_list_to_array_helper<SizeT>((SizeT)0, &index, list, ndarray);
|
||||
}
|
||||
} // namespace ndarray::array
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
using namespace ndarray::array;
|
||||
|
||||
void __nac3_ndarray_array_set_and_validate_list_shape(List<int32_t>* list, int32_t ndims, int32_t* shape) {
|
||||
set_and_validate_list_shape(list, ndims, shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_array_set_and_validate_list_shape64(List<int64_t>* list, int64_t ndims, int64_t* shape) {
|
||||
set_and_validate_list_shape(list, ndims, shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_array_write_list_to_array(List<int32_t>* list, NDArray<int32_t>* ndarray) {
|
||||
write_list_to_array(list, ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_array_write_list_to_array64(List<int64_t>* list, NDArray<int64_t>* ndarray) {
|
||||
write_list_to_array(list, ndarray);
|
||||
}
|
||||
}
|
340
nac3core/irrt/irrt/ndarray/basic.hpp
Normal file
340
nac3core/irrt/irrt/ndarray/basic.hpp
Normal file
@ -0,0 +1,340 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/debug.hpp"
|
||||
#include "irrt/exception.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/ndarray/def.hpp"
|
||||
|
||||
namespace {
|
||||
namespace ndarray::basic {
|
||||
/**
|
||||
* @brief Assert that `shape` does not contain negative dimensions.
|
||||
*
|
||||
* @param ndims Number of dimensions in `shape`
|
||||
* @param shape The shape to check on
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
|
||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||
if (shape[axis] < 0) {
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||
"negative dimensions are not allowed; axis {0} "
|
||||
"has dimension {1}",
|
||||
axis, shape[axis], NO_PARAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Assert that two shapes are the same in the context of writing output to an ndarray.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void assert_output_shape_same(SizeT ndarray_ndims,
|
||||
const SizeT* ndarray_shape,
|
||||
SizeT output_ndims,
|
||||
const SizeT* output_shape) {
|
||||
if (ndarray_ndims != output_ndims) {
|
||||
// There is no corresponding NumPy error message like this.
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}",
|
||||
output_ndims, ndarray_ndims, NO_PARAM);
|
||||
}
|
||||
|
||||
for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
|
||||
if (ndarray_shape[axis] != output_shape[axis]) {
|
||||
// There is no corresponding NumPy error message like this.
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||
"Mismatched dimensions on axis {0}, output has "
|
||||
"dimension {1}, but destination ndarray has dimension {2}.",
|
||||
axis, output_shape[axis], ndarray_shape[axis]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the number of elements of an ndarray given its shape.
|
||||
*
|
||||
* @param ndims Number of dimensions in `shape`
|
||||
* @param shape The shape of the ndarray
|
||||
*/
|
||||
template<typename SizeT>
|
||||
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
|
||||
SizeT size = 1;
|
||||
for (SizeT axis = 0; axis < ndims; axis++)
|
||||
size *= shape[axis];
|
||||
return size;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
|
||||
*
|
||||
* @param ndims Number of elements in `shape` and `indices`
|
||||
* @param shape The shape of the ndarray
|
||||
* @param indices The returned indices indexing the ndarray with shape `shape`.
|
||||
* @param nth The index of the element of interest.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
|
||||
for (SizeT i = 0; i < ndims; i++) {
|
||||
SizeT axis = ndims - i - 1;
|
||||
SizeT dim = shape[axis];
|
||||
|
||||
indices[axis] = nth % dim;
|
||||
nth /= dim;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the number of elements of an `ndarray`
|
||||
*
|
||||
* This function corresponds to `<an_ndarray>.size`
|
||||
*/
|
||||
template<typename SizeT>
|
||||
SizeT size(const NDArray<SizeT>* ndarray) {
|
||||
return calc_size_from_shape(ndarray->ndims, ndarray->shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return of the number of its content of an `ndarray`.
|
||||
*
|
||||
* This function corresponds to `<an_ndarray>.nbytes`.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
SizeT nbytes(const NDArray<SizeT>* ndarray) {
|
||||
return size(ndarray) * ndarray->itemsize;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
|
||||
*
|
||||
* This function corresponds to `<an_ndarray>.__len__`.
|
||||
*
|
||||
* @param dst_length The length.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
SizeT len(const NDArray<SizeT>* ndarray) {
|
||||
if (ndarray->ndims != 0) {
|
||||
return ndarray->shape[0];
|
||||
}
|
||||
|
||||
// numpy prohibits `__len__` on unsized objects
|
||||
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM);
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
|
||||
*
|
||||
* You may want to see ndarray's rules for C-contiguity:
|
||||
* https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
||||
*/
|
||||
template<typename SizeT>
|
||||
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
|
||||
// References:
|
||||
// - tinynumpy's implementation:
|
||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
|
||||
// - ndarray's flags["C_CONTIGUOUS"]:
|
||||
// https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
|
||||
// - ndarray's rules for C-contiguity:
|
||||
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
|
||||
|
||||
// From
|
||||
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
|
||||
//
|
||||
// The traditional rule is that for an array to be flagged as C contiguous,
|
||||
// the following must hold:
|
||||
//
|
||||
// strides[-1] == itemsize
|
||||
// strides[i] == shape[i+1] * strides[i + 1]
|
||||
// [...]
|
||||
// According to these rules, a 0- or 1-dimensional array is either both
|
||||
// C- and F-contiguous, or neither; and an array with 2+ dimensions
|
||||
// can be C- or F- contiguous, or neither, but not both. Though there
|
||||
// there are exceptions for arrays with zero or one item, in the first
|
||||
// case the check is relaxed up to and including the first dimension
|
||||
// with shape[i] == 0. In the second case `strides == itemsize` will
|
||||
// can be true for all dimensions and both flags are set.
|
||||
|
||||
if (ndarray->ndims == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (SizeT i = 1; i < ndarray->ndims; i++) {
|
||||
SizeT axis_i = ndarray->ndims - i - 1;
|
||||
if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the pointer to the element indexed by `indices` along the ndarray's axes.
|
||||
*
|
||||
* This function does no bound check.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void* get_pelement_by_indices(const NDArray<SizeT>* ndarray, const SizeT* indices) {
|
||||
void* element = ndarray->data;
|
||||
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
|
||||
element = static_cast<uint8_t*>(element) + indices[dim_i] * ndarray->strides[dim_i];
|
||||
return element;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view.
|
||||
*
|
||||
* This function does no bound check.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
|
||||
void* element = ndarray->data;
|
||||
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
||||
SizeT axis = ndarray->ndims - i - 1;
|
||||
SizeT dim = ndarray->shape[axis];
|
||||
element = static_cast<uint8_t*>(element) + ndarray->strides[axis] * (nth % dim);
|
||||
nth /= dim;
|
||||
}
|
||||
return element;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous.
|
||||
*
|
||||
* You might want to read https://ajcr.net/stride-guide-part-1/.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
|
||||
SizeT stride_product = 1;
|
||||
for (SizeT i = 0; i < ndarray->ndims; i++) {
|
||||
SizeT axis = ndarray->ndims - i - 1;
|
||||
ndarray->strides[axis] = stride_product * ndarray->itemsize;
|
||||
stride_product *= ndarray->shape[axis];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Set an element in `ndarray`.
|
||||
*
|
||||
* @param pelement Pointer to the element in `ndarray` to be set.
|
||||
* @param pvalue Pointer to the value `pelement` will be set to.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void set_pelement_value(NDArray<SizeT>* ndarray, void* pelement, const void* pvalue) {
|
||||
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
|
||||
*
|
||||
* Both ndarrays will be viewed in their flatten views when copying the elements.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
// TODO: Make this faster with memcpy when we see a contiguous segment.
|
||||
// TODO: Handle overlapping.
|
||||
|
||||
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
|
||||
|
||||
for (SizeT i = 0; i < size(src_ndarray); i++) {
|
||||
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
|
||||
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
|
||||
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
|
||||
}
|
||||
}
|
||||
} // namespace ndarray::basic
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
using namespace ndarray::basic;
|
||||
|
||||
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) {
|
||||
assert_shape_no_negative(ndims, shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) {
|
||||
assert_shape_no_negative(ndims, shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
|
||||
const int32_t* ndarray_shape,
|
||||
int32_t output_ndims,
|
||||
const int32_t* output_shape) {
|
||||
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims,
|
||||
const int64_t* ndarray_shape,
|
||||
int64_t output_ndims,
|
||||
const int64_t* output_shape) {
|
||||
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
|
||||
}
|
||||
|
||||
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
|
||||
return size(ndarray);
|
||||
}
|
||||
|
||||
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
|
||||
return size(ndarray);
|
||||
}
|
||||
|
||||
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
|
||||
return nbytes(ndarray);
|
||||
}
|
||||
|
||||
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
|
||||
return nbytes(ndarray);
|
||||
}
|
||||
|
||||
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
|
||||
return len(ndarray);
|
||||
}
|
||||
|
||||
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
|
||||
return len(ndarray);
|
||||
}
|
||||
|
||||
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
|
||||
return is_c_contiguous(ndarray);
|
||||
}
|
||||
|
||||
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
|
||||
return is_c_contiguous(ndarray);
|
||||
}
|
||||
|
||||
void* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
|
||||
return get_nth_pelement(ndarray, nth);
|
||||
}
|
||||
|
||||
void* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray, int64_t nth) {
|
||||
return get_nth_pelement(ndarray, nth);
|
||||
}
|
||||
|
||||
void* __nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t>* ndarray, int32_t* indices) {
|
||||
return get_pelement_by_indices(ndarray, indices);
|
||||
}
|
||||
|
||||
void* __nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t>* ndarray, int64_t* indices) {
|
||||
return get_pelement_by_indices(ndarray, indices);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
|
||||
set_strides_by_shape(ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
|
||||
set_strides_by_shape(ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
|
||||
copy_data(src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
|
||||
copy_data(src_ndarray, dst_ndarray);
|
||||
}
|
||||
}
|
165
nac3core/irrt/irrt/ndarray/broadcast.hpp
Normal file
165
nac3core/irrt/irrt/ndarray/broadcast.hpp
Normal file
@ -0,0 +1,165 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/ndarray/def.hpp"
|
||||
#include "irrt/slice.hpp"
|
||||
|
||||
namespace {
|
||||
template<typename SizeT>
|
||||
struct ShapeEntry {
|
||||
SizeT ndims;
|
||||
SizeT* shape;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
namespace ndarray::broadcast {
|
||||
/**
|
||||
* @brief Return true if `src_shape` can broadcast to `dst_shape`.
|
||||
*
|
||||
* See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
||||
*/
|
||||
template<typename SizeT>
|
||||
bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape, SizeT src_ndims, const SizeT* src_shape) {
|
||||
if (src_ndims > target_ndims) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (SizeT i = 0; i < src_ndims; i++) {
|
||||
SizeT target_dim = target_shape[target_ndims - i - 1];
|
||||
SizeT src_dim = src_shape[src_ndims - i - 1];
|
||||
if (!(src_dim == 1 || target_dim == src_dim)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs `np.broadcast_shapes(<shapes>)`
|
||||
*
|
||||
* @param num_shapes Number of entries in `shapes`
|
||||
* @param shapes The list of shape to do `np.broadcast_shapes` on.
|
||||
* @param dst_ndims The length of `dst_shape`.
|
||||
* `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it.
|
||||
* for this function since they should already know in order to allocate `dst_shape` in the first place.
|
||||
* @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result
|
||||
* of `np.broadcast_shapes` and write it here.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT>* shapes, SizeT dst_ndims, SizeT* dst_shape) {
|
||||
for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) {
|
||||
dst_shape[dst_axis] = 1;
|
||||
}
|
||||
|
||||
#ifdef IRRT_DEBUG_ASSERT
|
||||
SizeT max_ndims_found = 0;
|
||||
#endif
|
||||
|
||||
for (SizeT i = 0; i < num_shapes; i++) {
|
||||
ShapeEntry<SizeT> entry = shapes[i];
|
||||
|
||||
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
|
||||
debug_assert(SizeT, entry.ndims <= dst_ndims);
|
||||
|
||||
#ifdef IRRT_DEBUG_ASSERT
|
||||
max_ndims_found = max(max_ndims_found, entry.ndims);
|
||||
#endif
|
||||
|
||||
for (SizeT j = 0; j < entry.ndims; j++) {
|
||||
SizeT entry_axis = entry.ndims - j - 1;
|
||||
SizeT dst_axis = dst_ndims - j - 1;
|
||||
|
||||
SizeT entry_dim = entry.shape[entry_axis];
|
||||
SizeT dst_dim = dst_shape[dst_axis];
|
||||
|
||||
if (dst_dim == 1) {
|
||||
dst_shape[dst_axis] = entry_dim;
|
||||
} else if (entry_dim == 1 || entry_dim == dst_dim) {
|
||||
// Do nothing
|
||||
} else {
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR,
|
||||
"shape mismatch: objects cannot be broadcast "
|
||||
"to a single shape.",
|
||||
NO_PARAM, NO_PARAM, NO_PARAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef IRRT_DEBUG_ASSERT
|
||||
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
|
||||
debug_assert_eq(SizeT, max_ndims_found, dst_ndims);
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Perform `np.broadcast_to(<ndarray>, <target_shape>)` and appropriate assertions.
|
||||
*
|
||||
* This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`,
|
||||
* and return the result by modifying `dst_ndarray`.
|
||||
*
|
||||
* # Notes on `dst_ndarray`
|
||||
* The caller is responsible for allocating space for the resulting ndarray.
|
||||
* Here is what this function expects from `dst_ndarray` when called:
|
||||
* - `dst_ndarray->data` does not have to be initialized.
|
||||
* - `dst_ndarray->itemsize` does not have to be initialized.
|
||||
* - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape`
|
||||
* - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape.
|
||||
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
||||
* When this function call ends:
|
||||
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
|
||||
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
|
||||
* - `dst_ndarray->ndims` is unchanged.
|
||||
* - `dst_ndarray->shape` is unchanged.
|
||||
* - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void broadcast_to(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
if (!ndarray::broadcast::can_broadcast_shape_to(dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims,
|
||||
src_ndarray->shape)) {
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "operands could not be broadcast together", NO_PARAM, NO_PARAM,
|
||||
NO_PARAM);
|
||||
}
|
||||
|
||||
dst_ndarray->data = src_ndarray->data;
|
||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||
|
||||
for (SizeT i = 0; i < dst_ndarray->ndims; i++) {
|
||||
SizeT src_axis = src_ndarray->ndims - i - 1;
|
||||
SizeT dst_axis = dst_ndarray->ndims - i - 1;
|
||||
if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 && dst_ndarray->shape[dst_axis] != 1)) {
|
||||
// Freeze the steps in-place
|
||||
dst_ndarray->strides[dst_axis] = 0;
|
||||
} else {
|
||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace ndarray::broadcast
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
using namespace ndarray::broadcast;
|
||||
|
||||
void __nac3_ndarray_broadcast_to(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
|
||||
broadcast_to(src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_broadcast_to64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
|
||||
broadcast_to(src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_broadcast_shapes(int32_t num_shapes,
|
||||
const ShapeEntry<int32_t>* shapes,
|
||||
int32_t dst_ndims,
|
||||
int32_t* dst_shape) {
|
||||
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes,
|
||||
const ShapeEntry<int64_t>* shapes,
|
||||
int64_t dst_ndims,
|
||||
int64_t* dst_shape) {
|
||||
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
|
||||
}
|
||||
}
|
51
nac3core/irrt/irrt/ndarray/def.hpp
Normal file
51
nac3core/irrt/irrt/ndarray/def.hpp
Normal file
@ -0,0 +1,51 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
namespace {
|
||||
/**
|
||||
* @brief The NDArray object
|
||||
*
|
||||
* Official numpy implementation:
|
||||
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface
|
||||
*
|
||||
* Note that this implementation is based on `PyArrayInterface` rather of `PyArrayObject`. The
|
||||
* difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is
|
||||
* that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject`
|
||||
* does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its
|
||||
* `data`. There are also minor differences in the struct layout.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
struct NDArray {
|
||||
/**
|
||||
* @brief The number of bytes of a single element in `data`.
|
||||
*/
|
||||
SizeT itemsize;
|
||||
|
||||
/**
|
||||
* @brief The number of dimensions of this shape.
|
||||
*/
|
||||
SizeT ndims;
|
||||
|
||||
/**
|
||||
* @brief The NDArray shape, with length equal to `ndims`.
|
||||
*
|
||||
* Note that it may contain 0.
|
||||
*/
|
||||
SizeT* shape;
|
||||
|
||||
/**
|
||||
* @brief Array strides, with length equal to `ndims`
|
||||
*
|
||||
* The stride values are in units of bytes, not number of elements.
|
||||
*
|
||||
* Note that `strides` can have negative values or contain 0.
|
||||
*/
|
||||
SizeT* strides;
|
||||
|
||||
/**
|
||||
* @brief The underlying data this `ndarray` is pointing to.
|
||||
*/
|
||||
void* data;
|
||||
};
|
||||
} // namespace
|
219
nac3core/irrt/irrt/ndarray/indexing.hpp
Normal file
219
nac3core/irrt/irrt/ndarray/indexing.hpp
Normal file
@ -0,0 +1,219 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/exception.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/ndarray/basic.hpp"
|
||||
#include "irrt/ndarray/def.hpp"
|
||||
#include "irrt/range.hpp"
|
||||
#include "irrt/slice.hpp"
|
||||
|
||||
namespace {
|
||||
typedef uint8_t NDIndexType;
|
||||
|
||||
/**
|
||||
* @brief A single element index
|
||||
*
|
||||
* `data` points to a `int32_t`.
|
||||
*/
|
||||
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
|
||||
|
||||
/**
|
||||
* @brief A slice index
|
||||
*
|
||||
* `data` points to a `Slice<int32_t>`.
|
||||
*/
|
||||
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
|
||||
|
||||
/**
|
||||
* @brief `np.newaxis` / `None`
|
||||
*
|
||||
* `data` is unused.
|
||||
*/
|
||||
const NDIndexType ND_INDEX_TYPE_NEWAXIS = 2;
|
||||
|
||||
/**
|
||||
* @brief `Ellipsis` / `...`
|
||||
*
|
||||
* `data` is unused.
|
||||
*/
|
||||
const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3;
|
||||
|
||||
/**
|
||||
* @brief An index used in ndarray indexing
|
||||
*
|
||||
* That is:
|
||||
* ```
|
||||
* my_ndarray[::-1, 3, ..., np.newaxis]
|
||||
* ^^^^ ^ ^^^ ^^^^^^^^^^ each of these is represented by an NDIndex.
|
||||
* ```
|
||||
*/
|
||||
struct NDIndex {
|
||||
/**
|
||||
* @brief Enum tag to specify the type of index.
|
||||
*
|
||||
* Please see the comment of each enum constant.
|
||||
*/
|
||||
NDIndexType type;
|
||||
|
||||
/**
|
||||
* @brief The accompanying data associated with `type`.
|
||||
*
|
||||
* Please see the comment of each enum constant.
|
||||
*/
|
||||
uint8_t* data;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
namespace ndarray::indexing {
|
||||
/**
|
||||
* @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
|
||||
*
|
||||
* This function is very similar to performing `dst_ndarray = src_ndarray[indices]` in Python.
|
||||
*
|
||||
* This function also does proper assertions on `indices` to check for out of bounds access and more.
|
||||
*
|
||||
* # Notes on `dst_ndarray`
|
||||
* The caller is responsible for allocating space for the resulting ndarray.
|
||||
* Here is what this function expects from `dst_ndarray` when called:
|
||||
* - `dst_ndarray->data` does not have to be initialized.
|
||||
* - `dst_ndarray->itemsize` does not have to be initialized.
|
||||
* - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after
|
||||
* indexing `src_ndarray` with `indices`.
|
||||
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
|
||||
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
||||
* When this function call ends:
|
||||
* - `dst_ndarray->data` is set to `src_ndarray->data`.
|
||||
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`.
|
||||
* - `dst_ndarray->ndims` is unchanged.
|
||||
* - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed.
|
||||
* - `dst_ndarray->strides` is updated accordingly by how ndarray indexing works.
|
||||
*
|
||||
* @param indices indices to index `src_ndarray`, ordered in the same way you would write them in Python.
|
||||
* @param src_ndarray The NDArray to be indexed.
|
||||
* @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above,
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void index(SizeT num_indices, const NDIndex* indices, const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
// Validate `indices`.
|
||||
|
||||
// Expected value of `dst_ndarray->ndims`.
|
||||
SizeT expected_dst_ndims = src_ndarray->ndims;
|
||||
// To check for "too many indices for array: array is ?-dimensional, but ? were indexed"
|
||||
SizeT num_indexed = 0;
|
||||
// There may be ellipsis `...` in `indices`. There can only be 0 or 1 ellipsis.
|
||||
SizeT num_ellipsis = 0;
|
||||
|
||||
for (SizeT i = 0; i < num_indices; i++) {
|
||||
if (indices[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
|
||||
expected_dst_ndims--;
|
||||
num_indexed++;
|
||||
} else if (indices[i].type == ND_INDEX_TYPE_SLICE) {
|
||||
num_indexed++;
|
||||
} else if (indices[i].type == ND_INDEX_TYPE_NEWAXIS) {
|
||||
expected_dst_ndims++;
|
||||
} else if (indices[i].type == ND_INDEX_TYPE_ELLIPSIS) {
|
||||
num_ellipsis++;
|
||||
if (num_ellipsis > 1) {
|
||||
raise_exception(SizeT, EXN_INDEX_ERROR, "an index can only have a single ellipsis ('...')", NO_PARAM,
|
||||
NO_PARAM, NO_PARAM);
|
||||
}
|
||||
} else {
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
|
||||
debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims);
|
||||
|
||||
if (src_ndarray->ndims - num_indexed < 0) {
|
||||
raise_exception(SizeT, EXN_INDEX_ERROR,
|
||||
"too many indices for array: array is {0}-dimensional, "
|
||||
"but {1} were indexed",
|
||||
src_ndarray->ndims, num_indices, NO_PARAM);
|
||||
}
|
||||
|
||||
dst_ndarray->data = src_ndarray->data;
|
||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||
|
||||
// Reference code:
|
||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
|
||||
SizeT src_axis = 0;
|
||||
SizeT dst_axis = 0;
|
||||
|
||||
for (int32_t i = 0; i < num_indices; i++) {
|
||||
const NDIndex* index = &indices[i];
|
||||
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
|
||||
SizeT input = (SizeT) * ((int32_t*)index->data);
|
||||
|
||||
SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
|
||||
if (k == -1) {
|
||||
raise_exception(SizeT, EXN_INDEX_ERROR,
|
||||
"index {0} is out of bounds for axis {1} "
|
||||
"with size {2}",
|
||||
input, src_axis, src_ndarray->shape[src_axis]);
|
||||
}
|
||||
|
||||
dst_ndarray->data = static_cast<uint8_t*>(dst_ndarray->data) + k * src_ndarray->strides[src_axis];
|
||||
|
||||
src_axis++;
|
||||
} else if (index->type == ND_INDEX_TYPE_SLICE) {
|
||||
Slice<int32_t>* slice = (Slice<int32_t>*)index->data;
|
||||
|
||||
Range<int32_t> range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
|
||||
|
||||
dst_ndarray->data =
|
||||
static_cast<uint8_t*>(dst_ndarray->data) + (SizeT)range.start * src_ndarray->strides[src_axis];
|
||||
dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis];
|
||||
dst_ndarray->shape[dst_axis] = (SizeT)range.len<SizeT>();
|
||||
|
||||
dst_axis++;
|
||||
src_axis++;
|
||||
} else if (index->type == ND_INDEX_TYPE_NEWAXIS) {
|
||||
dst_ndarray->strides[dst_axis] = 0;
|
||||
dst_ndarray->shape[dst_axis] = 1;
|
||||
|
||||
dst_axis++;
|
||||
} else if (index->type == ND_INDEX_TYPE_ELLIPSIS) {
|
||||
// The number of ':' entries this '...' implies.
|
||||
SizeT ellipsis_size = src_ndarray->ndims - num_indexed;
|
||||
|
||||
for (SizeT j = 0; j < ellipsis_size; j++) {
|
||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
||||
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
|
||||
|
||||
dst_axis++;
|
||||
src_axis++;
|
||||
}
|
||||
} else {
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
|
||||
for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) {
|
||||
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
|
||||
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
|
||||
}
|
||||
|
||||
debug_assert_eq(SizeT, src_ndarray->ndims, src_axis);
|
||||
debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis);
|
||||
}
|
||||
} // namespace ndarray::indexing
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
using namespace ndarray::indexing;
|
||||
|
||||
void __nac3_ndarray_index(int32_t num_indices,
|
||||
NDIndex* indices,
|
||||
NDArray<int32_t>* src_ndarray,
|
||||
NDArray<int32_t>* dst_ndarray) {
|
||||
index(num_indices, indices, src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_index64(int64_t num_indices,
|
||||
NDIndex* indices,
|
||||
NDArray<int64_t>* src_ndarray,
|
||||
NDArray<int64_t>* dst_ndarray) {
|
||||
index(num_indices, indices, src_ndarray, dst_ndarray);
|
||||
}
|
||||
}
|
146
nac3core/irrt/irrt/ndarray/iter.hpp
Normal file
146
nac3core/irrt/irrt/ndarray/iter.hpp
Normal file
@ -0,0 +1,146 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/ndarray/def.hpp"
|
||||
|
||||
namespace {
|
||||
/**
|
||||
* @brief Helper struct to enumerate through an ndarray *efficiently*.
|
||||
*
|
||||
* Example usage (in pseudo-code):
|
||||
* ```
|
||||
* // Suppose my_ndarray has been initialized, with shape [2, 3] and dtype `double`
|
||||
* NDIter nditer;
|
||||
* nditer.initialize(my_ndarray);
|
||||
* while (nditer.has_element()) {
|
||||
* // This body is run 6 (= my_ndarray.size) times.
|
||||
*
|
||||
* // [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> end
|
||||
* print(nditer.indices);
|
||||
*
|
||||
* // 0 -> 1 -> 2 -> 3 -> 4 -> 5
|
||||
* print(nditer.nth);
|
||||
*
|
||||
* // <1st element> -> <2nd element> -> ... -> <6th element> -> end
|
||||
* print(*((double *) nditer.element))
|
||||
*
|
||||
* nditer.next(); // Go to next element.
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* Interesting cases:
|
||||
* - If `my_ndarray.ndims` == 0, there is one iteration.
|
||||
* - If `my_ndarray.shape` contains zeroes, there are no iterations.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
struct NDIter {
|
||||
// Information about the ndarray being iterated over.
|
||||
SizeT ndims;
|
||||
SizeT* shape;
|
||||
SizeT* strides;
|
||||
|
||||
/**
|
||||
* @brief The current indices.
|
||||
*
|
||||
* Must be allocated by the caller.
|
||||
*/
|
||||
SizeT* indices;
|
||||
|
||||
/**
|
||||
* @brief The nth (0-based) index of the current indices.
|
||||
*
|
||||
* Initially this is 0.
|
||||
*/
|
||||
SizeT nth;
|
||||
|
||||
/**
|
||||
* @brief Pointer to the current element.
|
||||
*
|
||||
* Initially this points to first element of the ndarray.
|
||||
*/
|
||||
void* element;
|
||||
|
||||
/**
|
||||
* @brief Cache for the product of shape.
|
||||
*
|
||||
* Could be 0 if `shape` has 0s in it.
|
||||
*/
|
||||
SizeT size;
|
||||
|
||||
void initialize(SizeT ndims, SizeT* shape, SizeT* strides, void* element, SizeT* indices) {
|
||||
this->ndims = ndims;
|
||||
this->shape = shape;
|
||||
this->strides = strides;
|
||||
|
||||
this->indices = indices;
|
||||
this->element = element;
|
||||
|
||||
// Compute size
|
||||
this->size = 1;
|
||||
for (SizeT i = 0; i < ndims; i++) {
|
||||
this->size *= shape[i];
|
||||
}
|
||||
|
||||
// `indices` starts on all 0s.
|
||||
for (SizeT axis = 0; axis < ndims; axis++)
|
||||
indices[axis] = 0;
|
||||
nth = 0;
|
||||
}
|
||||
|
||||
void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
|
||||
// NOTE: ndarray->data is pointing to the first element, and `NDIter`'s `element` should also point to the first
|
||||
// element as well.
|
||||
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices);
|
||||
}
|
||||
|
||||
// Is the current iteration valid?
|
||||
// If true, then `element`, `indices` and `nth` contain details about the current element.
|
||||
bool has_element() { return nth < size; }
|
||||
|
||||
// Go to the next element.
|
||||
void next() {
|
||||
for (SizeT i = 0; i < ndims; i++) {
|
||||
SizeT axis = ndims - i - 1;
|
||||
indices[axis]++;
|
||||
if (indices[axis] >= shape[axis]) {
|
||||
indices[axis] = 0;
|
||||
|
||||
// TODO: There is something called backstrides to speedup iteration.
|
||||
// See https://ajcr.net/stride-guide-part-1/, and
|
||||
// https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
|
||||
element = static_cast<void*>(reinterpret_cast<uint8_t*>(element) - strides[axis] * (shape[axis] - 1));
|
||||
} else {
|
||||
element = static_cast<void*>(reinterpret_cast<uint8_t*>(element) + strides[axis]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
nth++;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray, int32_t* indices) {
|
||||
iter->initialize_by_ndarray(ndarray, indices);
|
||||
}
|
||||
|
||||
void __nac3_nditer_initialize64(NDIter<int64_t>* iter, NDArray<int64_t>* ndarray, int64_t* indices) {
|
||||
iter->initialize_by_ndarray(ndarray, indices);
|
||||
}
|
||||
|
||||
bool __nac3_nditer_has_element(NDIter<int32_t>* iter) {
|
||||
return iter->has_element();
|
||||
}
|
||||
|
||||
bool __nac3_nditer_has_element64(NDIter<int64_t>* iter) {
|
||||
return iter->has_element();
|
||||
}
|
||||
|
||||
void __nac3_nditer_next(NDIter<int32_t>* iter) {
|
||||
iter->next();
|
||||
}
|
||||
|
||||
void __nac3_nditer_next64(NDIter<int64_t>* iter) {
|
||||
iter->next();
|
||||
}
|
||||
}
|
98
nac3core/irrt/irrt/ndarray/matmul.hpp
Normal file
98
nac3core/irrt/irrt/ndarray/matmul.hpp
Normal file
@ -0,0 +1,98 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/debug.hpp"
|
||||
#include "irrt/exception.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/ndarray/basic.hpp"
|
||||
#include "irrt/ndarray/broadcast.hpp"
|
||||
#include "irrt/ndarray/iter.hpp"
|
||||
|
||||
// NOTE: Everything would be much easier and elegant if einsum is implemented.
|
||||
|
||||
namespace {
|
||||
namespace ndarray::matmul {
|
||||
|
||||
/**
|
||||
* @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`.
|
||||
*
|
||||
* Example:
|
||||
* Suppose `a_shape == [1, 97, 4, 2]`
|
||||
* and `b_shape == [99, 98, 1, 2, 5]`,
|
||||
*
|
||||
* ...then `new_a_shape == [99, 98, 97, 4, 2]`,
|
||||
* `new_b_shape == [99, 98, 97, 2, 5]`,
|
||||
* and `dst_shape == [99, 98, 97, 4, 5]`.
|
||||
* ^^^^^^^^^^ ^^^^
|
||||
* (broadcasted) (4x2 @ 2x5 => 4x5)
|
||||
*
|
||||
* @param a_ndims Length of `a_shape`.
|
||||
* @param a_shape Shape of `a`.
|
||||
* @param b_ndims Length of `b_shape`.
|
||||
* @param b_shape Shape of `b`.
|
||||
* @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`,
|
||||
* `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void calculate_shapes(SizeT a_ndims,
|
||||
SizeT* a_shape,
|
||||
SizeT b_ndims,
|
||||
SizeT* b_shape,
|
||||
SizeT final_ndims,
|
||||
SizeT* new_a_shape,
|
||||
SizeT* new_b_shape,
|
||||
SizeT* dst_shape) {
|
||||
debug_assert(SizeT, a_ndims >= 2);
|
||||
debug_assert(SizeT, b_ndims >= 2);
|
||||
debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims);
|
||||
|
||||
// Check that a and b are compatible for matmul
|
||||
if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) {
|
||||
// This is a custom error message. Different from NumPy.
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})",
|
||||
a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM);
|
||||
}
|
||||
|
||||
const SizeT num_entries = 2;
|
||||
ShapeEntry<SizeT> entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape},
|
||||
{.ndims = b_ndims - 2, .shape = b_shape}};
|
||||
|
||||
// TODO: Optimize this
|
||||
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_a_shape);
|
||||
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, new_b_shape);
|
||||
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries, final_ndims - 2, dst_shape);
|
||||
|
||||
new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
||||
new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1];
|
||||
new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2];
|
||||
new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
||||
dst_shape[final_ndims - 2] = a_shape[a_ndims - 2];
|
||||
dst_shape[final_ndims - 1] = b_shape[b_ndims - 1];
|
||||
}
|
||||
} // namespace ndarray::matmul
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
using namespace ndarray::matmul;
|
||||
|
||||
void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims,
|
||||
int32_t* a_shape,
|
||||
int32_t b_ndims,
|
||||
int32_t* b_shape,
|
||||
int32_t final_ndims,
|
||||
int32_t* new_a_shape,
|
||||
int32_t* new_b_shape,
|
||||
int32_t* dst_shape) {
|
||||
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims,
|
||||
int64_t* a_shape,
|
||||
int64_t b_ndims,
|
||||
int64_t* b_shape,
|
||||
int64_t final_ndims,
|
||||
int64_t* new_a_shape,
|
||||
int64_t* new_b_shape,
|
||||
int64_t* dst_shape) {
|
||||
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape);
|
||||
}
|
||||
}
|
97
nac3core/irrt/irrt/ndarray/reshape.hpp
Normal file
97
nac3core/irrt/irrt/ndarray/reshape.hpp
Normal file
@ -0,0 +1,97 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/exception.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/ndarray/def.hpp"
|
||||
|
||||
namespace {
|
||||
namespace ndarray::reshape {
|
||||
/**
|
||||
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
|
||||
*
|
||||
* If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be
|
||||
* modified to contain the resolved dimension.
|
||||
*
|
||||
* To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual
|
||||
* `<ndarray>` object itself, but only the `.size` of the `<ndarray>`.
|
||||
*
|
||||
* @param size The `.size` of `<ndarray>`
|
||||
* @param new_ndims Number of elements in `new_shape`
|
||||
* @param new_shape Target shape to reshape to
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void resolve_and_check_new_shape(SizeT size, SizeT new_ndims, SizeT* new_shape) {
|
||||
// Is there a -1 in `new_shape`?
|
||||
bool neg1_exists = false;
|
||||
// Location of -1, only initialized if `neg1_exists` is true
|
||||
SizeT neg1_axis_i;
|
||||
// The computed ndarray size of `new_shape`
|
||||
SizeT new_size = 1;
|
||||
|
||||
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) {
|
||||
SizeT dim = new_shape[axis_i];
|
||||
if (dim < 0) {
|
||||
if (dim == -1) {
|
||||
if (neg1_exists) {
|
||||
// Multiple `-1` found. Throw an error.
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "can only specify one unknown dimension", NO_PARAM,
|
||||
NO_PARAM, NO_PARAM);
|
||||
} else {
|
||||
neg1_exists = true;
|
||||
neg1_axis_i = axis_i;
|
||||
}
|
||||
} else {
|
||||
// TODO: What? In `np.reshape` any negative dimensions is
|
||||
// treated like its `-1`.
|
||||
//
|
||||
// Try running `np.zeros((3, 4)).reshape((-999, 2))`
|
||||
//
|
||||
// It is not documented by numpy.
|
||||
// Throw an error for now...
|
||||
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "Found non -1 negative dimension {0} on axis {1}", dim, axis_i,
|
||||
NO_PARAM);
|
||||
}
|
||||
} else {
|
||||
new_size *= dim;
|
||||
}
|
||||
}
|
||||
|
||||
bool can_reshape;
|
||||
if (neg1_exists) {
|
||||
// Let `x` be the unknown dimension
|
||||
// Solve `x * <new_size> = <size>`
|
||||
if (new_size == 0 && size == 0) {
|
||||
// `x` has infinitely many solutions
|
||||
can_reshape = false;
|
||||
} else if (new_size == 0 && size != 0) {
|
||||
// `x` has no solutions
|
||||
can_reshape = false;
|
||||
} else if (size % new_size != 0) {
|
||||
// `x` has no integer solutions
|
||||
can_reshape = false;
|
||||
} else {
|
||||
can_reshape = true;
|
||||
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
|
||||
}
|
||||
} else {
|
||||
can_reshape = (new_size == size);
|
||||
}
|
||||
|
||||
if (!can_reshape) {
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "cannot reshape array of size {0} into given shape", size, NO_PARAM,
|
||||
NO_PARAM);
|
||||
}
|
||||
}
|
||||
} // namespace ndarray::reshape
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
void __nac3_ndarray_reshape_resolve_and_check_new_shape(int32_t size, int32_t new_ndims, int32_t* new_shape) {
|
||||
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_reshape_resolve_and_check_new_shape64(int64_t size, int64_t new_ndims, int64_t* new_shape) {
|
||||
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
|
||||
}
|
||||
}
|
143
nac3core/irrt/irrt/ndarray/transpose.hpp
Normal file
143
nac3core/irrt/irrt/ndarray/transpose.hpp
Normal file
@ -0,0 +1,143 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/debug.hpp"
|
||||
#include "irrt/exception.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/ndarray/def.hpp"
|
||||
#include "irrt/slice.hpp"
|
||||
|
||||
/*
|
||||
* Notes on `np.transpose(<array>, <axes>)`
|
||||
*
|
||||
* TODO: `axes`, if specified, can actually contain negative indices,
|
||||
* but it is not documented in numpy.
|
||||
*
|
||||
* Supporting it for now.
|
||||
*/
|
||||
|
||||
namespace {
|
||||
namespace ndarray::transpose {
|
||||
/**
|
||||
* @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
|
||||
*
|
||||
* Note that `np.transpose`'s `<axe>` argument is optional. If the argument
|
||||
* is specified but the user, use this function to do assertions on it.
|
||||
*
|
||||
* @param ndims The number of dimensions of `<array>`
|
||||
* @param num_axes Number of elements in `<axes>` as specified by the user.
|
||||
* This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown.
|
||||
* @param axes The user specified `<axes>`.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) {
|
||||
if (ndims != num_axes) {
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array", NO_PARAM, NO_PARAM, NO_PARAM);
|
||||
}
|
||||
|
||||
// TODO: Optimize this
|
||||
bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims);
|
||||
for (SizeT i = 0; i < ndims; i++)
|
||||
axe_specified[i] = false;
|
||||
|
||||
for (SizeT i = 0; i < ndims; i++) {
|
||||
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
|
||||
if (axis == -1) {
|
||||
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "axis {0} is out of bounds for array of dimension {1}", axis, ndims,
|
||||
NO_PARAM);
|
||||
}
|
||||
|
||||
if (axe_specified[axis]) {
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "repeated axis in transpose", NO_PARAM, NO_PARAM, NO_PARAM);
|
||||
}
|
||||
|
||||
axe_specified[axis] = true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create a transpose view of `src_ndarray` and perform proper assertions.
|
||||
*
|
||||
* This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, <axes>)`.
|
||||
* If `<axes>` is supposed to be `None`, caller can pass in a `nullptr` to `<axes>`.
|
||||
*
|
||||
* The transpose view created is returned by modifying `dst_ndarray`.
|
||||
*
|
||||
* The caller is responsible for setting up `dst_ndarray` before calling this function.
|
||||
* Here is what this function expects from `dst_ndarray` when called:
|
||||
* - `dst_ndarray->data` does not have to be initialized.
|
||||
* - `dst_ndarray->itemsize` does not have to be initialized.
|
||||
* - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`.
|
||||
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
|
||||
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
|
||||
* When this function call ends:
|
||||
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
|
||||
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
|
||||
* - `dst_ndarray->ndims` is unchanged
|
||||
* - `dst_ndarray->shape` is updated according to how `np.transpose` works
|
||||
* - `dst_ndarray->strides` is updated according to how `np.transpose` works
|
||||
*
|
||||
* @param src_ndarray The NDArray to build a transpose view on
|
||||
* @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above,
|
||||
* @param num_axes Number of elements in axes. Unused if `axes` is nullptr.
|
||||
* @param axes Axes permutation. Set it to `nullptr` if `<axes>` is `None`.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
void transpose(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray, SizeT num_axes, const SizeT* axes) {
|
||||
debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims);
|
||||
const auto ndims = src_ndarray->ndims;
|
||||
|
||||
if (axes != nullptr)
|
||||
assert_transpose_axes(ndims, num_axes, axes);
|
||||
|
||||
dst_ndarray->data = src_ndarray->data;
|
||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||
|
||||
// Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes.
|
||||
if (axes == nullptr) {
|
||||
// `np.transpose(<array>, axes=None)`
|
||||
|
||||
/*
|
||||
* Minor note: `np.transpose(<array>, axes=None)` is equivalent to
|
||||
* `np.transpose(<array>, axes=[N-1, N-2, ..., 0])` - basically it
|
||||
* is reversing the order of strides and shape.
|
||||
*
|
||||
* This is a fast implementation to handle this special (but very common) case.
|
||||
*/
|
||||
|
||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||
dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1];
|
||||
dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1];
|
||||
}
|
||||
} else {
|
||||
// `np.transpose(<array>, <axes>)`
|
||||
|
||||
// Permute strides and shape according to `axes`, while resolving negative indices in `axes`
|
||||
for (SizeT axis = 0; axis < ndims; axis++) {
|
||||
// `i` cannot be OUT_OF_BOUNDS because of assertions
|
||||
SizeT i = slice::resolve_index_in_length(ndims, axes[axis]);
|
||||
|
||||
dst_ndarray->shape[axis] = src_ndarray->shape[i];
|
||||
dst_ndarray->strides[axis] = src_ndarray->strides[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace ndarray::transpose
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
using namespace ndarray::transpose;
|
||||
void __nac3_ndarray_transpose(const NDArray<int32_t>* src_ndarray,
|
||||
NDArray<int32_t>* dst_ndarray,
|
||||
int32_t num_axes,
|
||||
const int32_t* axes) {
|
||||
transpose(src_ndarray, dst_ndarray, num_axes, axes);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_transpose64(const NDArray<int64_t>* src_ndarray,
|
||||
NDArray<int64_t>* dst_ndarray,
|
||||
int64_t num_axes,
|
||||
const int64_t* axes) {
|
||||
transpose(src_ndarray, dst_ndarray, num_axes, axes);
|
||||
}
|
||||
}
|
47
nac3core/irrt/irrt/range.hpp
Normal file
47
nac3core/irrt/irrt/range.hpp
Normal file
@ -0,0 +1,47 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/debug.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
namespace {
|
||||
namespace range {
|
||||
template<typename T>
|
||||
T len(T start, T stop, T step) {
|
||||
// Reference:
|
||||
// https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
|
||||
if (step > 0 && start < stop)
|
||||
return 1 + (stop - 1 - start) / step;
|
||||
else if (step < 0 && start > stop)
|
||||
return 1 + (start - 1 - stop) / (-step);
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
} // namespace range
|
||||
|
||||
/**
|
||||
* @brief A Python range.
|
||||
*/
|
||||
template<typename T>
|
||||
struct Range {
|
||||
T start;
|
||||
T stop;
|
||||
T step;
|
||||
|
||||
/**
|
||||
* @brief Calculate the `len()` of this range.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
T len() {
|
||||
debug_assert(SizeT, step != 0);
|
||||
return range::len(start, stop, step);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
using namespace range;
|
||||
|
||||
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
|
||||
return len(start, end, step);
|
||||
}
|
||||
}
|
@ -1,6 +1,145 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/debug.hpp"
|
||||
#include "irrt/exception.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/math_util.hpp"
|
||||
#include "irrt/range.hpp"
|
||||
|
||||
namespace {
|
||||
namespace slice {
|
||||
/**
|
||||
* @brief Resolve a possibly negative index in a list of a known length.
|
||||
*
|
||||
* Returns -1 if the resolved index is out of the list's bounds.
|
||||
*/
|
||||
template<typename T>
|
||||
T resolve_index_in_length(T length, T index) {
|
||||
T resolved = index < 0 ? length + index : index;
|
||||
if (0 <= resolved && resolved < length) {
|
||||
return resolved;
|
||||
} else {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Resolve a slice as a range.
|
||||
*
|
||||
* This is equivalent to `range(*slice(start, stop, step).indices(length))` in Python.
|
||||
*/
|
||||
template<typename T>
|
||||
void indices(bool start_defined,
|
||||
T start,
|
||||
bool stop_defined,
|
||||
T stop,
|
||||
bool step_defined,
|
||||
T step,
|
||||
T length,
|
||||
T* range_start,
|
||||
T* range_stop,
|
||||
T* range_step) {
|
||||
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
|
||||
*range_step = step_defined ? step : 1;
|
||||
bool step_is_negative = *range_step < 0;
|
||||
|
||||
T lower, upper;
|
||||
if (step_is_negative) {
|
||||
lower = -1;
|
||||
upper = length - 1;
|
||||
} else {
|
||||
lower = 0;
|
||||
upper = length;
|
||||
}
|
||||
|
||||
if (start_defined) {
|
||||
*range_start = start < 0 ? max(lower, start + length) : min(upper, start);
|
||||
} else {
|
||||
*range_start = step_is_negative ? upper : lower;
|
||||
}
|
||||
|
||||
if (stop_defined) {
|
||||
*range_stop = stop < 0 ? max(lower, stop + length) : min(upper, stop);
|
||||
} else {
|
||||
*range_stop = step_is_negative ? lower : upper;
|
||||
}
|
||||
}
|
||||
} // namespace slice
|
||||
|
||||
/**
|
||||
* @brief A Python-like slice with **unresolved** indices.
|
||||
*/
|
||||
template<typename T>
|
||||
struct Slice {
|
||||
bool start_defined;
|
||||
T start;
|
||||
|
||||
bool stop_defined;
|
||||
T stop;
|
||||
|
||||
bool step_defined;
|
||||
T step;
|
||||
|
||||
Slice() { this->reset(); }
|
||||
|
||||
void reset() {
|
||||
this->start_defined = false;
|
||||
this->stop_defined = false;
|
||||
this->step_defined = false;
|
||||
}
|
||||
|
||||
void set_start(T start) {
|
||||
this->start_defined = true;
|
||||
this->start = start;
|
||||
}
|
||||
|
||||
void set_stop(T stop) {
|
||||
this->stop_defined = true;
|
||||
this->stop = stop;
|
||||
}
|
||||
|
||||
void set_step(T step) {
|
||||
this->step_defined = true;
|
||||
this->step = step;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Resolve this slice as a range.
|
||||
*
|
||||
* In Python, this would be `range(*slice(start, stop, step).indices(length))`.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
Range<T> indices(T length) {
|
||||
// Reference:
|
||||
// https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
|
||||
debug_assert(SizeT, length >= 0);
|
||||
|
||||
Range<T> result;
|
||||
slice::indices(start_defined, start, stop_defined, stop, step_defined, step, length, &result.start,
|
||||
&result.stop, &result.step);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Like `.indices()` but with assertions.
|
||||
*/
|
||||
template<typename SizeT>
|
||||
Range<T> indices_checked(T length) {
|
||||
// TODO: Switch to `SizeT length`
|
||||
|
||||
if (length < 0) {
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, NO_PARAM,
|
||||
NO_PARAM);
|
||||
}
|
||||
|
||||
if (this->step_defined && this->step == 0) {
|
||||
raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero", NO_PARAM, NO_PARAM, NO_PARAM);
|
||||
}
|
||||
|
||||
return this->indices<SizeT>(length);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
||||
@ -14,15 +153,4 @@ SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
|
||||
SliceIndex diff = end - start;
|
||||
if (diff > 0 && step > 0) {
|
||||
return ((diff - 1) / step) + 1;
|
||||
} else if (diff < 0 && step < 0) {
|
||||
return ((diff + 1) / step) + 1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
23
nac3core/irrt/irrt/string.hpp
Normal file
23
nac3core/irrt/irrt/string.hpp
Normal file
@ -0,0 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
namespace {
|
||||
template<typename SizeT>
|
||||
bool __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) {
|
||||
if (len1 != len2) {
|
||||
return 0;
|
||||
}
|
||||
return __builtin_memcmp(str1, str2, static_cast<SizeT>(len1)) == 0;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
bool nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) {
|
||||
return __nac3_str_eq_impl<uint32_t>(str1, len1, str2, len2);
|
||||
}
|
||||
|
||||
bool nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) {
|
||||
return __nac3_str_eq_impl<uint64_t>(str1, len1, str2, len2);
|
||||
}
|
||||
}
|
21
nac3core/nac3core_derive/Cargo.toml
Normal file
21
nac3core/nac3core_derive/Cargo.toml
Normal file
@ -0,0 +1,21 @@
|
||||
[package]
|
||||
name = "nac3core_derive"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[[test]]
|
||||
name = "structfields_tests"
|
||||
path = "tests/structfields_test.rs"
|
||||
|
||||
[dev-dependencies]
|
||||
nac3core = { path = ".." }
|
||||
trybuild = { version = "1.0", features = ["diff"] }
|
||||
|
||||
[dependencies]
|
||||
proc-macro2 = "1.0"
|
||||
proc-macro-error = "1.0"
|
||||
syn = "2.0"
|
||||
quote = "1.0"
|
320
nac3core/nac3core_derive/src/lib.rs
Normal file
320
nac3core/nac3core_derive/src/lib.rs
Normal file
@ -0,0 +1,320 @@
|
||||
use proc_macro::TokenStream;
|
||||
use proc_macro_error::{abort, proc_macro_error};
|
||||
use quote::quote;
|
||||
use syn::{
|
||||
parse_macro_input, spanned::Spanned, Data, DataStruct, Expr, ExprField, ExprMethodCall,
|
||||
ExprPath, GenericArgument, Ident, LitStr, Path, PathArguments, Type, TypePath,
|
||||
};
|
||||
|
||||
/// Extracts all generic arguments of a [`Type`] into a [`Vec`].
|
||||
///
|
||||
/// Returns [`Some`] of a possibly-empty [`Vec`] if the path of `ty` matches with
|
||||
/// `expected_ty_name`, otherwise returns [`None`].
|
||||
fn extract_generic_args(expected_ty_name: &'static str, ty: &Type) -> Option<Vec<GenericArgument>> {
|
||||
let Type::Path(TypePath { qself: None, path, .. }) = ty else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let segments = &path.segments;
|
||||
if segments.len() != 1 {
|
||||
return None;
|
||||
};
|
||||
|
||||
let segment = segments.iter().next().unwrap();
|
||||
if segment.ident != expected_ty_name {
|
||||
return None;
|
||||
}
|
||||
|
||||
let PathArguments::AngleBracketed(path_args) = &segment.arguments else {
|
||||
return Some(Vec::new());
|
||||
};
|
||||
let args = &path_args.args;
|
||||
|
||||
Some(args.iter().cloned().collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
/// Maps a `path` matching one of the `target_idents` into the `replacement` [`Ident`].
|
||||
fn map_path_to_ident(path: &Path, target_idents: &[&str], replacement: &str) -> Option<Ident> {
|
||||
path.require_ident()
|
||||
.ok()
|
||||
.filter(|ident| target_idents.iter().any(|target| ident == target))
|
||||
.map(|ident| Ident::new(replacement, ident.span()))
|
||||
}
|
||||
|
||||
/// Extracts the left-hand side of a dot-expression.
|
||||
fn extract_dot_operand(expr: &Expr) -> Option<&Expr> {
|
||||
match expr {
|
||||
Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
|
||||
| Expr::Field(ExprField { base: operand, .. }) => Some(operand),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Replaces the top-level receiver of a dot-expression with an [`Ident`], returning `Some(&mut expr)` if the
|
||||
/// replacement is performed.
|
||||
///
|
||||
/// The top-level receiver is the left-most receiver expression, e.g. the top-level receiver of `a.b.c.foo()` is `a`.
|
||||
fn replace_top_level_receiver(expr: &mut Expr, ident: Ident) -> Option<&mut Expr> {
|
||||
if let Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
|
||||
| Expr::Field(ExprField { base: operand, .. }) = expr
|
||||
{
|
||||
return if extract_dot_operand(operand).is_some() {
|
||||
if replace_top_level_receiver(operand, ident).is_some() {
|
||||
Some(expr)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
*operand = Box::new(Expr::Path(ExprPath {
|
||||
attrs: Vec::default(),
|
||||
qself: None,
|
||||
path: ident.into(),
|
||||
}));
|
||||
|
||||
Some(expr)
|
||||
};
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Iterates all operands to the left-hand side of the `.` of an [expression][`Expr`], i.e. the container operand of all
|
||||
/// [`Expr::Field`] and the receiver operand of all [`Expr::MethodCall`].
|
||||
///
|
||||
/// The iterator will return the operand expressions in reverse order of appearance. For example, `a.b.c.func()` will
|
||||
/// return `vec![c, b, a]`.
|
||||
fn iter_dot_operands(expr: &Expr) -> impl Iterator<Item = &Expr> {
|
||||
let mut o = extract_dot_operand(expr);
|
||||
|
||||
std::iter::from_fn(move || {
|
||||
let this = o;
|
||||
o = o.as_ref().and_then(|o| extract_dot_operand(o));
|
||||
|
||||
this
|
||||
})
|
||||
}
|
||||
|
||||
/// Normalizes a value expression for use when creating an instance of this structure, returning a
|
||||
/// [`proc_macro2::TokenStream`] of tokens representing the normalized expression.
|
||||
fn normalize_value_expr(expr: &Expr) -> proc_macro2::TokenStream {
|
||||
match &expr {
|
||||
Expr::Path(ExprPath { qself: None, path, .. }) => {
|
||||
if let Some(ident) = map_path_to_ident(path, &["usize", "size_t"], "llvm_usize") {
|
||||
quote! { #ident }
|
||||
} else {
|
||||
abort!(
|
||||
path,
|
||||
format!(
|
||||
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
|
||||
quote!(#expr).to_string(),
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
Expr::Call(_) => {
|
||||
quote! { ctx.#expr }
|
||||
}
|
||||
|
||||
Expr::MethodCall(_) => {
|
||||
let base_receiver = iter_dot_operands(expr).last();
|
||||
|
||||
match base_receiver {
|
||||
// `usize.{...}`, `size_t.{...}` -> Rewrite the identifiers to `llvm_usize`
|
||||
Some(Expr::Path(ExprPath { qself: None, path, .. }))
|
||||
if map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").is_some() =>
|
||||
{
|
||||
let ident =
|
||||
map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").unwrap();
|
||||
|
||||
let mut expr = expr.clone();
|
||||
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
|
||||
|
||||
quote!(#expr)
|
||||
}
|
||||
|
||||
// `ctx.{...}`, `context.{...}` -> Rewrite the identifiers to `ctx`
|
||||
Some(Expr::Path(ExprPath { qself: None, path, .. }))
|
||||
if map_path_to_ident(path, &["ctx", "context"], "ctx").is_some() =>
|
||||
{
|
||||
let ident = map_path_to_ident(path, &["ctx", "context"], "ctx").unwrap();
|
||||
|
||||
let mut expr = expr.clone();
|
||||
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
|
||||
|
||||
quote!(#expr)
|
||||
}
|
||||
|
||||
// No reserved identifier prefix -> Prepend `ctx.` to the entire expression
|
||||
_ => quote! { ctx.#expr },
|
||||
}
|
||||
}
|
||||
|
||||
_ => {
|
||||
abort!(
|
||||
expr,
|
||||
format!(
|
||||
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
|
||||
quote!(#expr).to_string(),
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Derives an implementation of `codegen::types::structure::StructFields`.
|
||||
///
|
||||
/// The benefit of using `#[derive(StructFields)]` is that all index- or order-dependent logic required by
|
||||
/// `impl StructFields` is automatically generated by this implementation, including the field index as required by
|
||||
/// `StructField::new` and the fields as returned by `StructFields::to_vec`.
|
||||
///
|
||||
/// # Prerequisites
|
||||
///
|
||||
/// In order to derive from [`StructFields`], you must implement (or derive) [`Eq`] and [`Copy`] as required by
|
||||
/// `StructFields`.
|
||||
///
|
||||
/// Moreover, `#[derive(StructFields)]` can only be used for `struct`s with named fields, and may only contain fields
|
||||
/// with either `StructField` or [`PhantomData`] types.
|
||||
///
|
||||
/// # Attributes for [`StructFields`]
|
||||
///
|
||||
/// Each `StructField` field must be declared with the `#[value_type(...)]` attribute. The argument of `value_type`
|
||||
/// accepts one of the following:
|
||||
///
|
||||
/// - An expression returning an instance of `inkwell::types::BasicType` (with or without the receiver `ctx`/`context`).
|
||||
/// For example, `context.i8_type()`, `ctx.i8_type()`, and `i8_type()` all refer to `i8`.
|
||||
/// - The reserved identifiers `usize` and `size_t` referring to an `inkwell::types::IntType` of the platform-dependent
|
||||
/// integer size. `usize` and `size_t` can also be used as the receiver to other method calls, e.g.
|
||||
/// `usize.array_type(3)`.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// The following is an example of an LLVM slice implemented using `#[derive(StructFields)]`.
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use nac3core::{
|
||||
/// codegen::types::structure::StructField,
|
||||
/// inkwell::{
|
||||
/// values::{IntValue, PointerValue},
|
||||
/// AddressSpace,
|
||||
/// },
|
||||
/// };
|
||||
/// use nac3core_derive::StructFields;
|
||||
///
|
||||
/// // All classes that implement StructFields must also implement Eq and Copy
|
||||
/// #[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
/// pub struct SliceValue<'ctx> {
|
||||
/// // Declares ptr have a value type of i8*
|
||||
/// //
|
||||
/// // Can also be written as `ctx.i8_type().ptr_type(...)` or `context.i8_type().ptr_type(...)`
|
||||
/// #[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||
///
|
||||
/// // Declares len have a value type of usize, depending on the target compilation platform
|
||||
/// #[value_type(usize)]
|
||||
/// len: StructField<'ctx, IntValue<'ctx>>,
|
||||
/// }
|
||||
/// ```
|
||||
#[proc_macro_derive(StructFields, attributes(value_type))]
|
||||
#[proc_macro_error]
|
||||
pub fn derive(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as syn::DeriveInput);
|
||||
let ident = &input.ident;
|
||||
|
||||
let Data::Struct(DataStruct { fields, .. }) = &input.data else {
|
||||
abort!(input, "Only structs with named fields are supported");
|
||||
};
|
||||
if let Err(err_span) =
|
||||
fields
|
||||
.iter()
|
||||
.try_for_each(|field| if field.ident.is_some() { Ok(()) } else { Err(field.span()) })
|
||||
{
|
||||
abort!(err_span, "Only structs with named fields are supported");
|
||||
};
|
||||
|
||||
// Check if struct<'ctx>
|
||||
if input.generics.params.len() != 1 {
|
||||
abort!(input.generics, "Expected exactly 1 generic parameter")
|
||||
}
|
||||
|
||||
let phantom_info = fields
|
||||
.iter()
|
||||
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_some())
|
||||
.map(|field| field.ident.as_ref().unwrap())
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let field_info = fields
|
||||
.iter()
|
||||
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_none())
|
||||
.map(|field| {
|
||||
let ident = field.ident.as_ref().unwrap();
|
||||
let ty = &field.ty;
|
||||
|
||||
let Some(_) = extract_generic_args("StructField", ty) else {
|
||||
abort!(field, "Only StructField and PhantomData are allowed")
|
||||
};
|
||||
|
||||
let attrs = &field.attrs;
|
||||
let Some(value_type_attr) =
|
||||
attrs.iter().find(|attr| attr.path().is_ident("value_type"))
|
||||
else {
|
||||
abort!(field, "Expected #[value_type(...)] attribute for field");
|
||||
};
|
||||
|
||||
let Ok(value_type_expr) = value_type_attr.parse_args::<Expr>() else {
|
||||
abort!(value_type_attr, "Expected expression in #[value_type(...)]");
|
||||
};
|
||||
|
||||
let value_expr_toks = normalize_value_expr(&value_type_expr);
|
||||
|
||||
(ident.clone(), value_expr_toks)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// `<*>::new` impl of `StructField` and `PhantomData` for `StructFields::new`
|
||||
let phantoms_create = phantom_info
|
||||
.iter()
|
||||
.map(|id| quote! { #id: ::std::marker::PhantomData })
|
||||
.collect::<Vec<_>>();
|
||||
let fields_create = field_info
|
||||
.iter()
|
||||
.map(|(id, ty)| {
|
||||
let id_lit = LitStr::new(&id.to_string(), id.span());
|
||||
quote! {
|
||||
#id: ::nac3core::codegen::types::structure::StructField::create(
|
||||
&mut counter,
|
||||
#id_lit,
|
||||
#ty,
|
||||
)
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// `.into()` impl of `StructField` for `StructFields::to_vec`
|
||||
let fields_into =
|
||||
field_info.iter().map(|(id, _)| quote! { self.#id.into() }).collect::<Vec<_>>();
|
||||
|
||||
let impl_block = quote! {
|
||||
impl<'ctx> ::nac3core::codegen::types::structure::StructFields<'ctx> for #ident<'ctx> {
|
||||
fn new(ctx: impl ::nac3core::inkwell::context::AsContextRef<'ctx>, llvm_usize: ::nac3core::inkwell::types::IntType<'ctx>) -> Self {
|
||||
let ctx = unsafe { ::nac3core::inkwell::context::ContextRef::new(ctx.as_ctx_ref()) };
|
||||
|
||||
let mut counter = ::nac3core::codegen::types::structure::FieldIndexCounter::default();
|
||||
|
||||
#ident {
|
||||
#(#fields_create),*
|
||||
#(#phantoms_create),*
|
||||
}
|
||||
}
|
||||
|
||||
fn to_vec(&self) -> ::std::vec::Vec<(&'static str, ::nac3core::inkwell::types::BasicTypeEnum<'ctx>)> {
|
||||
vec![
|
||||
#(#fields_into),*
|
||||
]
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
impl_block.into()
|
||||
}
|
9
nac3core/nac3core_derive/tests/structfields_empty.rs
Normal file
9
nac3core/nac3core_derive/tests/structfields_empty.rs
Normal file
@ -0,0 +1,9 @@
|
||||
use nac3core_derive::StructFields;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct EmptyValue<'ctx> {
|
||||
_phantom: PhantomData<&'ctx ()>,
|
||||
}
|
||||
|
||||
fn main() {}
|
20
nac3core/nac3core_derive/tests/structfields_ndarray.rs
Normal file
20
nac3core/nac3core_derive/tests/structfields_ndarray.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use nac3core::{
|
||||
codegen::types::structure::StructField,
|
||||
inkwell::{
|
||||
values::{IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
},
|
||||
};
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct NDArrayValue<'ctx> {
|
||||
#[value_type(usize)]
|
||||
ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||
shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||
data: StructField<'ctx, PointerValue<'ctx>>,
|
||||
}
|
||||
|
||||
fn main() {}
|
18
nac3core/nac3core_derive/tests/structfields_slice.rs
Normal file
18
nac3core/nac3core_derive/tests/structfields_slice.rs
Normal file
@ -0,0 +1,18 @@
|
||||
use nac3core::{
|
||||
codegen::types::structure::StructField,
|
||||
inkwell::{
|
||||
values::{IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
},
|
||||
};
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct SliceValue<'ctx> {
|
||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||
#[value_type(usize)]
|
||||
len: StructField<'ctx, IntValue<'ctx>>,
|
||||
}
|
||||
|
||||
fn main() {}
|
18
nac3core/nac3core_derive/tests/structfields_slice_context.rs
Normal file
18
nac3core/nac3core_derive/tests/structfields_slice_context.rs
Normal file
@ -0,0 +1,18 @@
|
||||
use nac3core::{
|
||||
codegen::types::structure::StructField,
|
||||
inkwell::{
|
||||
values::{IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
},
|
||||
};
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct SliceValue<'ctx> {
|
||||
#[value_type(context.i8_type().ptr_type(AddressSpace::default()))]
|
||||
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||
#[value_type(usize)]
|
||||
len: StructField<'ctx, IntValue<'ctx>>,
|
||||
}
|
||||
|
||||
fn main() {}
|
18
nac3core/nac3core_derive/tests/structfields_slice_ctx.rs
Normal file
18
nac3core/nac3core_derive/tests/structfields_slice_ctx.rs
Normal file
@ -0,0 +1,18 @@
|
||||
use nac3core::{
|
||||
codegen::types::structure::StructField,
|
||||
inkwell::{
|
||||
values::{IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
},
|
||||
};
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct SliceValue<'ctx> {
|
||||
#[value_type(ctx.i8_type().ptr_type(AddressSpace::default()))]
|
||||
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||
#[value_type(usize)]
|
||||
len: StructField<'ctx, IntValue<'ctx>>,
|
||||
}
|
||||
|
||||
fn main() {}
|
18
nac3core/nac3core_derive/tests/structfields_slice_sizet.rs
Normal file
18
nac3core/nac3core_derive/tests/structfields_slice_sizet.rs
Normal file
@ -0,0 +1,18 @@
|
||||
use nac3core::{
|
||||
codegen::types::structure::StructField,
|
||||
inkwell::{
|
||||
values::{IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
},
|
||||
};
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct SliceValue<'ctx> {
|
||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||
ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||
#[value_type(size_t)]
|
||||
len: StructField<'ctx, IntValue<'ctx>>,
|
||||
}
|
||||
|
||||
fn main() {}
|
10
nac3core/nac3core_derive/tests/structfields_test.rs
Normal file
10
nac3core/nac3core_derive/tests/structfields_test.rs
Normal file
@ -0,0 +1,10 @@
|
||||
#[test]
|
||||
fn test_parse_empty() {
|
||||
let t = trybuild::TestCases::new();
|
||||
t.pass("tests/structfields_empty.rs");
|
||||
t.pass("tests/structfields_slice.rs");
|
||||
t.pass("tests/structfields_slice_ctx.rs");
|
||||
t.pass("tests/structfields_slice_context.rs");
|
||||
t.pass("tests/structfields_slice_sizet.rs");
|
||||
t.pass("tests/structfields_ndarray.rs");
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -56,6 +56,10 @@ pub enum ConcreteTypeEnum {
|
||||
fields: HashMap<StrRef, (ConcreteType, bool)>,
|
||||
params: IndexMap<TypeVarId, ConcreteType>,
|
||||
},
|
||||
TModule {
|
||||
module_id: DefinitionId,
|
||||
methods: HashMap<StrRef, (ConcreteType, bool)>,
|
||||
},
|
||||
TVirtual {
|
||||
ty: ConcreteType,
|
||||
},
|
||||
@ -205,6 +209,19 @@ impl ConcreteTypeStore {
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
TypeEnum::TModule { module_id, attributes } => ConcreteTypeEnum::TModule {
|
||||
module_id: *module_id,
|
||||
methods: attributes
|
||||
.iter()
|
||||
.filter_map(|(name, ty)| match &*unifier.get_ty(ty.0) {
|
||||
TypeEnum::TFunc(..) | TypeEnum::TObj { .. } => None,
|
||||
_ => Some((
|
||||
*name,
|
||||
(self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1),
|
||||
)),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
TypeEnum::TVirtual { ty } => ConcreteTypeEnum::TVirtual {
|
||||
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
|
||||
},
|
||||
@ -284,6 +301,15 @@ impl ConcreteTypeStore {
|
||||
TypeVar { id, ty }
|
||||
})),
|
||||
},
|
||||
ConcreteTypeEnum::TModule { module_id, methods } => TypeEnum::TModule {
|
||||
module_id: *module_id,
|
||||
attributes: methods
|
||||
.iter()
|
||||
.map(|(name, cty)| {
|
||||
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
|
||||
})
|
||||
.collect::<HashMap<_, _>>(),
|
||||
},
|
||||
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
|
||||
args: args
|
||||
.iter()
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,12 +1,13 @@
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
targets::TargetMachine,
|
||||
types::{BasicTypeEnum, IntType},
|
||||
values::{BasicValueEnum, IntValue, PointerValue},
|
||||
};
|
||||
|
||||
use nac3parser::ast::{Expr, Stmt, StrRef};
|
||||
|
||||
use super::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext};
|
||||
use super::{bool_to_i1, bool_to_i8, expr::*, stmt::*, values::ArraySliceValue, CodeGenContext};
|
||||
use crate::{
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{DefinitionId, TopLevelDef},
|
||||
@ -17,6 +18,10 @@ pub trait CodeGenerator {
|
||||
/// Return the module name for the code generator.
|
||||
fn get_name(&self) -> &str;
|
||||
|
||||
/// Return an instance of [`IntType`] corresponding to the type of `size_t` for this instance.
|
||||
///
|
||||
/// Prefer using [`CodeGenContext::get_size_type`] if [`CodeGenContext`] is available, as it is
|
||||
/// equivalent to this function in a more concise syntax.
|
||||
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx>;
|
||||
|
||||
/// Generate function call and returns the function return value.
|
||||
@ -269,19 +274,27 @@ pub struct DefaultCodeGenerator {
|
||||
|
||||
impl DefaultCodeGenerator {
|
||||
#[must_use]
|
||||
pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator {
|
||||
assert!(matches!(size_t, 32 | 64));
|
||||
DefaultCodeGenerator { name, size_t }
|
||||
pub fn new(name: String, size_t: IntType<'_>) -> DefaultCodeGenerator {
|
||||
assert!(matches!(size_t.get_bit_width(), 32 | 64));
|
||||
DefaultCodeGenerator { name, size_t: size_t.get_bit_width() }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_target_machine(
|
||||
name: String,
|
||||
ctx: &Context,
|
||||
target_machine: &TargetMachine,
|
||||
) -> DefaultCodeGenerator {
|
||||
let llvm_usize = ctx.ptr_sized_int_type(&target_machine.get_target_data(), None);
|
||||
Self::new(name, llvm_usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl CodeGenerator for DefaultCodeGenerator {
|
||||
/// Returns the name for this [`CodeGenerator`].
|
||||
fn get_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
/// Returns an LLVM integer type representing `size_t`.
|
||||
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
|
||||
// it should be unsigned, but we don't really need unsigned and this could save us from
|
||||
// having to do a bit cast...
|
||||
|
174
nac3core/src/codegen/irrt/list.rs
Normal file
174
nac3core/src/codegen/irrt/list.rs
Normal file
@ -0,0 +1,174 @@
|
||||
use inkwell::{
|
||||
types::BasicTypeEnum,
|
||||
values::{BasicValueEnum, CallSiteValue, IntValue},
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
use itertools::Either;
|
||||
|
||||
use super::calculate_len_for_slice_range;
|
||||
use crate::codegen::{
|
||||
macros::codegen_unreachable,
|
||||
values::{ArrayLikeValue, ListValue},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// This function handles 'end' **inclusively**.
|
||||
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
|
||||
/// Negative index should be handled before entering this function
|
||||
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: BasicTypeEnum<'ctx>,
|
||||
dest_arr: ListValue<'ctx>,
|
||||
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||
src_arr: ListValue<'ctx>,
|
||||
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||
) {
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
|
||||
assert_eq!(dest_idx.0.get_type(), llvm_i32);
|
||||
assert_eq!(dest_idx.1.get_type(), llvm_i32);
|
||||
assert_eq!(dest_idx.2.get_type(), llvm_i32);
|
||||
assert_eq!(src_idx.0.get_type(), llvm_i32);
|
||||
assert_eq!(src_idx.1.get_type(), llvm_i32);
|
||||
assert_eq!(src_idx.2.get_type(), llvm_i32);
|
||||
|
||||
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", llvm_pi8);
|
||||
let slice_assign_fun = {
|
||||
let ty_vec = vec![
|
||||
llvm_i32.into(), // dest start idx
|
||||
llvm_i32.into(), // dest end idx
|
||||
llvm_i32.into(), // dest step
|
||||
elem_ptr_type.into(), // dest arr ptr
|
||||
llvm_i32.into(), // dest arr len
|
||||
llvm_i32.into(), // src start idx
|
||||
llvm_i32.into(), // src end idx
|
||||
llvm_i32.into(), // src step
|
||||
elem_ptr_type.into(), // src arr ptr
|
||||
llvm_i32.into(), // src arr len
|
||||
llvm_i32.into(), // size
|
||||
];
|
||||
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
|
||||
let fn_t = llvm_i32.fn_type(ty_vec.as_slice(), false);
|
||||
ctx.module.add_function(fun_symbol, fn_t, None)
|
||||
})
|
||||
};
|
||||
|
||||
let zero = llvm_i32.const_zero();
|
||||
let one = llvm_i32.const_int(1, false);
|
||||
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
|
||||
let dest_arr_ptr =
|
||||
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
|
||||
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
|
||||
let dest_len =
|
||||
ctx.builder.build_int_truncate_or_bit_cast(dest_len, llvm_i32, "srclen32").unwrap();
|
||||
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
|
||||
let src_arr_ptr =
|
||||
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
|
||||
let src_len = src_arr.load_size(ctx, Some("src.len"));
|
||||
let src_len =
|
||||
ctx.builder.build_int_truncate_or_bit_cast(src_len, llvm_i32, "srclen32").unwrap();
|
||||
|
||||
// index in bound and positive should be done
|
||||
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
|
||||
// throw exception if not satisfied
|
||||
let src_end = ctx
|
||||
.builder
|
||||
.build_select(
|
||||
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
|
||||
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
|
||||
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
|
||||
"final_e",
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let dest_end = ctx
|
||||
.builder
|
||||
.build_select(
|
||||
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
|
||||
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
|
||||
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
|
||||
"final_e",
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let src_slice_len =
|
||||
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
|
||||
let dest_slice_len =
|
||||
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
|
||||
let src_eq_dest = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
|
||||
.unwrap();
|
||||
let src_slt_dest = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
|
||||
.unwrap();
|
||||
let dest_step_eq_one = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
dest_idx.2,
|
||||
dest_idx.2.get_type().const_int(1, false),
|
||||
"slice_dest_step_eq_one",
|
||||
)
|
||||
.unwrap();
|
||||
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
|
||||
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
cond,
|
||||
"0:ValueError",
|
||||
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
|
||||
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let new_len = {
|
||||
let args = vec![
|
||||
dest_idx.0.into(), // dest start idx
|
||||
dest_idx.1.into(), // dest end idx
|
||||
dest_idx.2.into(), // dest step
|
||||
dest_arr_ptr.into(), // dest arr ptr
|
||||
dest_len.into(), // dest arr len
|
||||
src_idx.0.into(), // src start idx
|
||||
src_idx.1.into(), // src end idx
|
||||
src_idx.2.into(), // src step
|
||||
src_arr_ptr.into(), // src arr ptr
|
||||
src_len.into(), // src arr len
|
||||
{
|
||||
let s = match ty {
|
||||
BasicTypeEnum::FloatType(t) => t.size_of(),
|
||||
BasicTypeEnum::IntType(t) => t.size_of(),
|
||||
BasicTypeEnum::PointerType(t) => t.size_of(),
|
||||
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
|
||||
_ => codegen_unreachable!(ctx),
|
||||
};
|
||||
ctx.builder.build_int_truncate_or_bit_cast(s, llvm_i32, "size").unwrap()
|
||||
}
|
||||
.into(),
|
||||
];
|
||||
ctx.builder
|
||||
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
// update length
|
||||
let need_update =
|
||||
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
|
||||
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
let update_bb = ctx.ctx.append_basic_block(current, "update");
|
||||
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
|
||||
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
|
||||
ctx.builder.position_at_end(update_bb);
|
||||
let new_len =
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(new_len, llvm_usize, "new_len").unwrap();
|
||||
dest_arr.store_size(ctx, new_len);
|
||||
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
||||
ctx.builder.position_at_end(cont_bb);
|
||||
}
|
168
nac3core/src/codegen/irrt/math.rs
Normal file
168
nac3core/src/codegen/irrt/math.rs
Normal file
@ -0,0 +1,168 @@
|
||||
use inkwell::{
|
||||
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
||||
IntPredicate,
|
||||
};
|
||||
use itertools::Either;
|
||||
|
||||
use crate::codegen::{
|
||||
macros::codegen_unreachable,
|
||||
{CodeGenContext, CodeGenerator},
|
||||
};
|
||||
|
||||
// repeated squaring method adapted from GNU Scientific Library:
|
||||
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
base: IntValue<'ctx>,
|
||||
exp: IntValue<'ctx>,
|
||||
signed: bool,
|
||||
) -> IntValue<'ctx> {
|
||||
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
|
||||
(32, 32, true) => "__nac3_int_exp_int32_t",
|
||||
(64, 64, true) => "__nac3_int_exp_int64_t",
|
||||
(32, 32, false) => "__nac3_int_exp_uint32_t",
|
||||
(64, 64, false) => "__nac3_int_exp_uint64_t",
|
||||
_ => codegen_unreachable!(ctx),
|
||||
};
|
||||
let base_type = base.get_type();
|
||||
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
|
||||
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
|
||||
ctx.module.add_function(symbol, fn_type, None)
|
||||
});
|
||||
// throw exception when exp < 0
|
||||
let ge_zero = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::SGE,
|
||||
exp,
|
||||
exp.get_type().const_zero(),
|
||||
"assert_int_pow_ge_0",
|
||||
)
|
||||
.unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ge_zero,
|
||||
"0:ValueError",
|
||||
"integer power must be positive or zero",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
ctx.builder
|
||||
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
|
||||
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
v: FloatValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
assert_eq!(v.get_type(), llvm_f64);
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
|
||||
let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false);
|
||||
ctx.module.add_function("__nac3_isinf", fn_type, None)
|
||||
});
|
||||
|
||||
let ret = ctx
|
||||
.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "isinf")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap();
|
||||
|
||||
generator.bool_to_i1(ctx, ret)
|
||||
}
|
||||
|
||||
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
|
||||
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
v: FloatValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
assert_eq!(v.get_type(), llvm_f64);
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
|
||||
let fn_type = llvm_i32.fn_type(&[llvm_f64.into()], false);
|
||||
ctx.module.add_function("__nac3_isnan", fn_type, None)
|
||||
});
|
||||
|
||||
let ret = ctx
|
||||
.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "isnan")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap();
|
||||
|
||||
generator.bool_to_i1(ctx, ret)
|
||||
}
|
||||
|
||||
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
|
||||
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
assert_eq!(v.get_type(), llvm_f64);
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
ctx.module.add_function("__nac3_gamma", fn_type, None)
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "gamma")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
|
||||
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
assert_eq!(v.get_type(), llvm_f64);
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
ctx.module.add_function("__nac3_gammaln", fn_type, None)
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "gammaln")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
|
||||
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
assert_eq!(v.get_type(), llvm_f64);
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
ctx.module.add_function("__nac3_j0", fn_type, None)
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "j0")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
@ -3,25 +3,26 @@ use inkwell::{
|
||||
context::Context,
|
||||
memory_buffer::MemoryBuffer,
|
||||
module::Module,
|
||||
types::{BasicTypeEnum, IntType},
|
||||
values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
||||
AddressSpace, IntPredicate,
|
||||
values::{BasicValue, BasicValueEnum, IntValue},
|
||||
IntPredicate,
|
||||
};
|
||||
use itertools::Either;
|
||||
|
||||
use nac3parser::ast::Expr;
|
||||
|
||||
use super::{
|
||||
classes::{
|
||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
|
||||
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
|
||||
},
|
||||
llvm_intrinsics,
|
||||
macros::codegen_unreachable,
|
||||
stmt::gen_for_callback_incrementing,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
use super::{CodeGenContext, CodeGenerator};
|
||||
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
|
||||
pub use list::*;
|
||||
pub use math::*;
|
||||
pub use range::*;
|
||||
pub use slice::*;
|
||||
pub use string::*;
|
||||
|
||||
mod list;
|
||||
mod math;
|
||||
pub mod ndarray;
|
||||
mod range;
|
||||
mod slice;
|
||||
mod string;
|
||||
|
||||
#[must_use]
|
||||
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
||||
@ -62,86 +63,21 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
|
||||
irrt_mod
|
||||
}
|
||||
|
||||
// repeated squaring method adapted from GNU Scientific Library:
|
||||
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
base: IntValue<'ctx>,
|
||||
exp: IntValue<'ctx>,
|
||||
signed: bool,
|
||||
) -> IntValue<'ctx> {
|
||||
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
|
||||
(32, 32, true) => "__nac3_int_exp_int32_t",
|
||||
(64, 64, true) => "__nac3_int_exp_int64_t",
|
||||
(32, 32, false) => "__nac3_int_exp_uint32_t",
|
||||
(64, 64, false) => "__nac3_int_exp_uint64_t",
|
||||
_ => codegen_unreachable!(ctx),
|
||||
};
|
||||
let base_type = base.get_type();
|
||||
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
|
||||
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
|
||||
ctx.module.add_function(symbol, fn_type, None)
|
||||
});
|
||||
// throw exception when exp < 0
|
||||
let ge_zero = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::SGE,
|
||||
exp,
|
||||
exp.get_type().const_zero(),
|
||||
"assert_int_pow_ge_0",
|
||||
)
|
||||
.unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ge_zero,
|
||||
"0:ValueError",
|
||||
"integer power must be positive or zero",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
ctx.builder
|
||||
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
start: IntValue<'ctx>,
|
||||
end: IntValue<'ctx>,
|
||||
step: IntValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
const SYMBOL: &str = "__nac3_range_slice_len";
|
||||
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
||||
let i32_t = ctx.ctx.i32_type();
|
||||
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
|
||||
ctx.module.add_function(SYMBOL, fn_t, None)
|
||||
});
|
||||
|
||||
// assert step != 0, throw exception if not
|
||||
let not_zero = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
|
||||
.unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
not_zero,
|
||||
"0:ValueError",
|
||||
"step must not be zero",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
ctx.builder
|
||||
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
/// Returns the name of a function which contains variants for 32-bit and 64-bit `size_t`.
|
||||
///
|
||||
/// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`.
|
||||
/// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`.
|
||||
#[must_use]
|
||||
pub fn get_usize_dependent_function_name(ctx: &CodeGenContext<'_, '_>, name: &str) -> String {
|
||||
let mut name = name.to_owned();
|
||||
match ctx.get_size_type().get_bit_width() {
|
||||
32 => {}
|
||||
64 => name.push_str("64"),
|
||||
bit_width => {
|
||||
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
|
||||
}
|
||||
}
|
||||
name
|
||||
}
|
||||
|
||||
/// NOTE: the output value of the end index of this function should be compared ***inclusively***,
|
||||
@ -192,10 +128,11 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
length: IntValue<'ctx>,
|
||||
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let zero = int32.const_zero();
|
||||
let one = int32.const_int(1, false);
|
||||
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
|
||||
let zero = llvm_i32.const_zero();
|
||||
let one = llvm_i32.const_int(1, false);
|
||||
let length = ctx.builder.build_int_truncate_or_bit_cast(length, llvm_i32, "leni32").unwrap();
|
||||
Ok(Some(match (start, end, step) {
|
||||
(s, e, None) => (
|
||||
if let Some(s) = s.as_ref() {
|
||||
@ -204,7 +141,7 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
|
||||
None => return Ok(None),
|
||||
}
|
||||
} else {
|
||||
int32.const_zero()
|
||||
llvm_i32.const_zero()
|
||||
},
|
||||
{
|
||||
let e = if let Some(s) = e.as_ref() {
|
||||
@ -309,644 +246,3 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
/// this function allows index out of range, since python
|
||||
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
|
||||
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
|
||||
i: &Expr<Option<Type>>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
length: IntValue<'ctx>,
|
||||
) -> Result<Option<IntValue<'ctx>>, String> {
|
||||
const SYMBOL: &str = "__nac3_slice_index_bound";
|
||||
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
||||
let i32_t = ctx.ctx.i32_type();
|
||||
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
|
||||
ctx.module.add_function(SYMBOL, fn_t, None)
|
||||
});
|
||||
|
||||
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
|
||||
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
Ok(Some(
|
||||
ctx.builder
|
||||
.build_call(func, &[i.into(), length.into()], "bounded_ind")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap(),
|
||||
))
|
||||
}
|
||||
|
||||
/// This function handles 'end' **inclusively**.
|
||||
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
|
||||
/// Negative index should be handled before entering this function
|
||||
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: BasicTypeEnum<'ctx>,
|
||||
dest_arr: ListValue<'ctx>,
|
||||
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||
src_arr: ListValue<'ctx>,
|
||||
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
|
||||
) {
|
||||
let size_ty = generator.get_size_type(ctx.ctx);
|
||||
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
|
||||
let slice_assign_fun = {
|
||||
let ty_vec = vec![
|
||||
int32.into(), // dest start idx
|
||||
int32.into(), // dest end idx
|
||||
int32.into(), // dest step
|
||||
elem_ptr_type.into(), // dest arr ptr
|
||||
int32.into(), // dest arr len
|
||||
int32.into(), // src start idx
|
||||
int32.into(), // src end idx
|
||||
int32.into(), // src step
|
||||
elem_ptr_type.into(), // src arr ptr
|
||||
int32.into(), // src arr len
|
||||
int32.into(), // size
|
||||
];
|
||||
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
|
||||
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
|
||||
ctx.module.add_function(fun_symbol, fn_t, None)
|
||||
})
|
||||
};
|
||||
|
||||
let zero = int32.const_zero();
|
||||
let one = int32.const_int(1, false);
|
||||
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
|
||||
let dest_arr_ptr =
|
||||
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
|
||||
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
|
||||
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
|
||||
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
|
||||
let src_arr_ptr =
|
||||
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
|
||||
let src_len = src_arr.load_size(ctx, Some("src.len"));
|
||||
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
|
||||
|
||||
// index in bound and positive should be done
|
||||
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
|
||||
// throw exception if not satisfied
|
||||
let src_end = ctx
|
||||
.builder
|
||||
.build_select(
|
||||
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
|
||||
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
|
||||
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
|
||||
"final_e",
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let dest_end = ctx
|
||||
.builder
|
||||
.build_select(
|
||||
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
|
||||
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
|
||||
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
|
||||
"final_e",
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
let src_slice_len =
|
||||
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
|
||||
let dest_slice_len =
|
||||
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
|
||||
let src_eq_dest = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
|
||||
.unwrap();
|
||||
let src_slt_dest = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
|
||||
.unwrap();
|
||||
let dest_step_eq_one = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
dest_idx.2,
|
||||
dest_idx.2.get_type().const_int(1, false),
|
||||
"slice_dest_step_eq_one",
|
||||
)
|
||||
.unwrap();
|
||||
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
|
||||
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
cond,
|
||||
"0:ValueError",
|
||||
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
|
||||
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let new_len = {
|
||||
let args = vec![
|
||||
dest_idx.0.into(), // dest start idx
|
||||
dest_idx.1.into(), // dest end idx
|
||||
dest_idx.2.into(), // dest step
|
||||
dest_arr_ptr.into(), // dest arr ptr
|
||||
dest_len.into(), // dest arr len
|
||||
src_idx.0.into(), // src start idx
|
||||
src_idx.1.into(), // src end idx
|
||||
src_idx.2.into(), // src step
|
||||
src_arr_ptr.into(), // src arr ptr
|
||||
src_len.into(), // src arr len
|
||||
{
|
||||
let s = match ty {
|
||||
BasicTypeEnum::FloatType(t) => t.size_of(),
|
||||
BasicTypeEnum::IntType(t) => t.size_of(),
|
||||
BasicTypeEnum::PointerType(t) => t.size_of(),
|
||||
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
|
||||
_ => codegen_unreachable!(ctx),
|
||||
};
|
||||
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
|
||||
}
|
||||
.into(),
|
||||
];
|
||||
ctx.builder
|
||||
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
};
|
||||
// update length
|
||||
let need_update =
|
||||
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
|
||||
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
let update_bb = ctx.ctx.append_basic_block(current, "update");
|
||||
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
|
||||
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
|
||||
ctx.builder.position_at_end(update_bb);
|
||||
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
|
||||
dest_arr.store_size(ctx, generator, new_len);
|
||||
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
|
||||
ctx.builder.position_at_end(cont_bb);
|
||||
}
|
||||
|
||||
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
|
||||
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
v: FloatValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
|
||||
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
|
||||
ctx.module.add_function("__nac3_isinf", fn_type, None)
|
||||
});
|
||||
|
||||
let ret = ctx
|
||||
.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "isinf")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap();
|
||||
|
||||
generator.bool_to_i1(ctx, ret)
|
||||
}
|
||||
|
||||
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
|
||||
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
v: FloatValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
|
||||
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
|
||||
ctx.module.add_function("__nac3_isnan", fn_type, None)
|
||||
});
|
||||
|
||||
let ret = ctx
|
||||
.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "isnan")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap();
|
||||
|
||||
generator.bool_to_i1(ctx, ret)
|
||||
}
|
||||
|
||||
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
|
||||
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
ctx.module.add_function("__nac3_gamma", fn_type, None)
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "gamma")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
|
||||
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
ctx.module.add_function("__nac3_gammaln", fn_type, None)
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "gammaln")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
|
||||
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
|
||||
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
|
||||
ctx.module.add_function("__nac3_j0", fn_type, None)
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[v.into()], "j0")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
|
||||
/// calculated total size.
|
||||
///
|
||||
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
|
||||
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
|
||||
/// or [`None`] if starting from the first dimension and ending at the last dimension
|
||||
/// respectively.
|
||||
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dims: &Dims,
|
||||
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
|
||||
) -> IntValue<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Dims: ArrayLikeIndexer<'ctx>,
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_size",
|
||||
64 => "__nac3_ndarray_calc_size64",
|
||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
|
||||
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
|
||||
false,
|
||||
);
|
||||
let ndarray_calc_size_fn =
|
||||
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
|
||||
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
|
||||
});
|
||||
|
||||
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
|
||||
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_size_fn,
|
||||
&[
|
||||
dims.base_ptr(ctx, generator).into(),
|
||||
dims.size(ctx, generator).into(),
|
||||
begin.into(),
|
||||
end.into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
|
||||
/// containing `i32` indices of the flattened index.
|
||||
///
|
||||
/// * `index` - The index to compute the multidimensional index for.
|
||||
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
index: IntValue<'ctx>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||
let llvm_void = ctx.ctx.void_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_nd_indices",
|
||||
64 => "__nac3_ndarray_calc_nd_indices64",
|
||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_nd_indices_fn =
|
||||
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_void.fn_type(
|
||||
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let ndarray_dims = ndarray.dim_sizes();
|
||||
|
||||
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_nd_indices_fn,
|
||||
&[
|
||||
index.into(),
|
||||
ndarray_dims.base_ptr(ctx, generator).into(),
|
||||
ndarray_num_dims.into(),
|
||||
indices.into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
|
||||
Box::new(|_, v| v.into_int_value()),
|
||||
Box::new(|_, v| v.into()),
|
||||
)
|
||||
}
|
||||
|
||||
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: &Indices,
|
||||
) -> IntValue<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Indices: ArrayLikeIndexer<'ctx>,
|
||||
{
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
debug_assert_eq!(
|
||||
IntType::try_from(indices.element_type(ctx, generator))
|
||||
.map(IntType::get_bit_width)
|
||||
.unwrap_or_default(),
|
||||
llvm_i32.get_bit_width(),
|
||||
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
|
||||
);
|
||||
debug_assert_eq!(
|
||||
indices.size(ctx, generator).get_type().get_bit_width(),
|
||||
llvm_usize.get_bit_width(),
|
||||
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
|
||||
);
|
||||
|
||||
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_flatten_index",
|
||||
64 => "__nac3_ndarray_flatten_index64",
|
||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_flatten_index_fn =
|
||||
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let ndarray_dims = ndarray.dim_sizes();
|
||||
|
||||
let index = ctx
|
||||
.builder
|
||||
.build_call(
|
||||
ndarray_flatten_index_fn,
|
||||
&[
|
||||
ndarray_dims.base_ptr(ctx, generator).into(),
|
||||
ndarray_num_dims.into(),
|
||||
indices.base_ptr(ctx, generator).into(),
|
||||
indices.size(ctx, generator).into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap();
|
||||
|
||||
index
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
|
||||
/// multidimensional index.
|
||||
///
|
||||
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: &Index,
|
||||
) -> IntValue<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Index: ArrayLikeIndexer<'ctx>,
|
||||
{
|
||||
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
|
||||
/// dimension and size of each dimension of the resultant `ndarray`.
|
||||
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
lhs: NDArrayValue<'ctx>,
|
||||
rhs: NDArrayValue<'ctx>,
|
||||
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_broadcast",
|
||||
64 => "__nac3_ndarray_calc_broadcast64",
|
||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_broadcast_fn =
|
||||
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[
|
||||
llvm_pusize.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let lhs_ndims = lhs.load_ndims(ctx);
|
||||
let rhs_ndims = rhs.load_ndims(ctx);
|
||||
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(min_ndims, false),
|
||||
|generator, ctx, _, idx| {
|
||||
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
|
||||
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
|
||||
(
|
||||
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
|
||||
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
|
||||
)
|
||||
};
|
||||
|
||||
let llvm_usize_const_one = llvm_usize.const_int(1, false);
|
||||
let lhs_eqz = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
|
||||
.unwrap();
|
||||
let rhs_eqz = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
|
||||
.unwrap();
|
||||
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
|
||||
|
||||
let lhs_eq_rhs = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
|
||||
.unwrap();
|
||||
|
||||
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
is_compatible,
|
||||
"0:ValueError",
|
||||
"operands could not be broadcast together",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
|
||||
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
|
||||
let lhs_ndims = lhs.load_ndims(ctx);
|
||||
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
|
||||
let rhs_ndims = rhs.load_ndims(ctx);
|
||||
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
|
||||
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_broadcast_fn,
|
||||
&[
|
||||
lhs_dims.into(),
|
||||
lhs_ndims.into(),
|
||||
rhs_dims.into(),
|
||||
rhs_ndims.into(),
|
||||
out_dims.base_ptr(ctx, generator).into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
TypedArrayLikeAdapter::from(
|
||||
out_dims,
|
||||
Box::new(|_, v| v.into_int_value()),
|
||||
Box::new(|_, v| v.into()),
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
|
||||
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
|
||||
/// array `broadcast_idx`.
|
||||
pub fn call_ndarray_calc_broadcast_index<
|
||||
'ctx,
|
||||
G: CodeGenerator + ?Sized,
|
||||
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
|
||||
>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
array: NDArrayValue<'ctx>,
|
||||
broadcast_idx: &BroadcastIdx,
|
||||
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_broadcast_idx",
|
||||
64 => "__nac3_ndarray_calc_broadcast_idx64",
|
||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_broadcast_fn =
|
||||
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let broadcast_size = broadcast_idx.size(ctx, generator);
|
||||
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
|
||||
|
||||
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
|
||||
let array_ndims = array.load_ndims(ctx);
|
||||
let broadcast_idx_ptr = unsafe {
|
||||
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_broadcast_fn,
|
||||
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
|
||||
Box::new(|_, v| v.into_int_value()),
|
||||
Box::new(|_, v| v.into()),
|
||||
)
|
||||
}
|
||||
|
72
nac3core/src/codegen/irrt/ndarray/array.rs
Normal file
72
nac3core/src/codegen/irrt/ndarray/array.rs
Normal file
@ -0,0 +1,72 @@
|
||||
use inkwell::{types::BasicTypeEnum, values::IntValue};
|
||||
|
||||
use crate::codegen::{
|
||||
expr::infer_and_call_function,
|
||||
irrt::get_usize_dependent_function_name,
|
||||
values::{ndarray::NDArrayValue, ListValue, ProxyValue, TypedArrayLikeAccessor},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_array_set_and_validate_list_shape`.
|
||||
///
|
||||
/// Deduces the target shape of the `ndarray` from the provided `list`, raising an exception if
|
||||
/// there is any issue with the resultant `shape`.
|
||||
///
|
||||
/// `shape` must be pre-allocated by the caller of this function to `[usize; ndims]`, and must be
|
||||
/// initialized to all `-1`s.
|
||||
pub fn call_nac3_ndarray_array_set_and_validate_list_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
list: ListValue<'ctx>,
|
||||
ndims: IntValue<'ctx>,
|
||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) {
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into());
|
||||
assert_eq!(ndims.get_type(), llvm_usize);
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name =
|
||||
get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_set_and_validate_list_shape");
|
||||
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[list.as_abi_value(ctx).into(), ndims.into(), shape.base_ptr(ctx, generator).into()],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_array_write_list_to_array`.
|
||||
///
|
||||
/// Copies the contents stored in `list` into `ndarray`.
|
||||
///
|
||||
/// The `ndarray` must fulfill the following preconditions:
|
||||
///
|
||||
/// - `ndarray.itemsize`: Must be initialized.
|
||||
/// - `ndarray.ndims`: Must be initialized.
|
||||
/// - `ndarray.shape`: Must be initialized.
|
||||
/// - `ndarray.data`: Must be allocated and contiguous.
|
||||
pub fn call_nac3_ndarray_array_write_list_to_array<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
list: ListValue<'ctx>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) {
|
||||
assert_eq!(list.get_type().element_type().unwrap(), ctx.ctx.i8_type().into());
|
||||
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_array_write_list_to_array");
|
||||
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[list.as_abi_value(ctx).into(), ndarray.as_abi_value(ctx).into()],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
298
nac3core/src/codegen/irrt/ndarray/basic.rs
Normal file
298
nac3core/src/codegen/irrt/ndarray/basic.rs
Normal file
@ -0,0 +1,298 @@
|
||||
use inkwell::{
|
||||
types::BasicTypeEnum,
|
||||
values::{BasicValueEnum, IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
};
|
||||
|
||||
use crate::codegen::{
|
||||
expr::{create_and_call_function, infer_and_call_function},
|
||||
irrt::get_usize_dependent_function_name,
|
||||
types::ProxyType,
|
||||
values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_util_assert_shape_no_negative`.
|
||||
///
|
||||
/// Assets that `shape` does not contain negative dimensions.
|
||||
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) {
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name =
|
||||
get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_shape_no_negative");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
Some(llvm_usize.into()),
|
||||
&[
|
||||
(llvm_usize.into(), shape.size(ctx, generator).into()),
|
||||
(llvm_pusize.into(), shape.base_ptr(ctx, generator).into()),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_util_assert_shape_output_shape_same`.
|
||||
///
|
||||
/// Asserts that `ndarray_shape` and `output_shape` are the same in the context of writing output to
|
||||
/// an `ndarray`.
|
||||
pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
output_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) {
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(ndarray_shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(output_shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name =
|
||||
get_usize_dependent_function_name(ctx, "__nac3_ndarray_util_assert_output_shape_same");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
Some(llvm_usize.into()),
|
||||
&[
|
||||
(llvm_usize.into(), ndarray_shape.size(ctx, generator).into()),
|
||||
(llvm_pusize.into(), ndarray_shape.base_ptr(ctx, generator).into()),
|
||||
(llvm_usize.into(), output_shape.size(ctx, generator).into()),
|
||||
(llvm_pusize.into(), output_shape.base_ptr(ctx, generator).into()),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_size`.
|
||||
///
|
||||
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of elements of an
|
||||
/// `ndarray`, corresponding to the value of `ndarray.size`.
|
||||
pub fn call_nac3_ndarray_size<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_ndarray = ndarray.get_type();
|
||||
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_size");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
Some(llvm_usize.into()),
|
||||
&[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())],
|
||||
Some("size"),
|
||||
None,
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_nbytes`.
|
||||
///
|
||||
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the number of bytes consumed by the
|
||||
/// data of the `ndarray`, corresponding to the value of `ndarray.nbytes`.
|
||||
pub fn call_nac3_ndarray_nbytes<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_ndarray = ndarray.get_type();
|
||||
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_nbytes");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
Some(llvm_usize.into()),
|
||||
&[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())],
|
||||
Some("nbytes"),
|
||||
None,
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_len`.
|
||||
///
|
||||
/// Returns a [`usize`][CodeGenerator::get_size_type] value of the size of the topmost dimension of
|
||||
/// the `ndarray`, corresponding to the value of `ndarray.__len__`.
|
||||
pub fn call_nac3_ndarray_len<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_ndarray = ndarray.get_type();
|
||||
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_len");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
Some(llvm_usize.into()),
|
||||
&[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())],
|
||||
Some("len"),
|
||||
None,
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_is_c_contiguous`.
|
||||
///
|
||||
/// Returns an `i1` value indicating whether the `ndarray` is C-contiguous.
|
||||
pub fn call_nac3_ndarray_is_c_contiguous<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_ndarray = ndarray.get_type();
|
||||
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_is_c_contiguous");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
Some(llvm_i1.into()),
|
||||
&[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())],
|
||||
Some("is_c_contiguous"),
|
||||
None,
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_get_nth_pelement`.
|
||||
///
|
||||
/// Returns a [`PointerValue`] to the `index`-th flattened element of the `ndarray`.
|
||||
pub fn call_nac3_ndarray_get_nth_pelement<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
index: IntValue<'ctx>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_ndarray = ndarray.get_type();
|
||||
|
||||
assert_eq!(index.get_type(), llvm_usize);
|
||||
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_nth_pelement");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
Some(llvm_pi8.into()),
|
||||
&[
|
||||
(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()),
|
||||
(llvm_usize.into(), index.into()),
|
||||
],
|
||||
Some("pelement"),
|
||||
None,
|
||||
)
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_get_pelement_by_indices`.
|
||||
///
|
||||
/// `indices` must have the same number of elements as the number of dimensions in `ndarray`.
|
||||
///
|
||||
/// Returns a [`PointerValue`] to the element indexed by `indices`.
|
||||
pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
let llvm_ndarray = ndarray.get_type();
|
||||
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_get_pelement_by_indices");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
Some(llvm_pi8.into()),
|
||||
&[
|
||||
(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into()),
|
||||
(llvm_pusize.into(), indices.base_ptr(ctx, generator).into()),
|
||||
],
|
||||
Some("pelement"),
|
||||
None,
|
||||
)
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_set_strides_by_shape`.
|
||||
///
|
||||
/// Sets `ndarray.strides` assuming that `ndarray.shape` is C-contiguous.
|
||||
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) {
|
||||
let llvm_ndarray = ndarray.get_type();
|
||||
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_set_strides_by_shape");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[(llvm_ndarray.as_abi_type().into(), ndarray.as_abi_value(ctx).into())],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_copy_data`.
|
||||
///
|
||||
/// Copies all elements from `src_ndarray` to `dst_ndarray` using their flattened views. The number
|
||||
/// of elements in `src_ndarray` must be greater than or equal to the number of elements in
|
||||
/// `dst_ndarray`.
|
||||
pub fn call_nac3_ndarray_copy_data<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
src_ndarray: NDArrayValue<'ctx>,
|
||||
dst_ndarray: NDArrayValue<'ctx>,
|
||||
) {
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_copy_data");
|
||||
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[src_ndarray.as_abi_value(ctx).into(), dst_ndarray.as_abi_value(ctx).into()],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
80
nac3core/src/codegen/irrt/ndarray/broadcast.rs
Normal file
80
nac3core/src/codegen/irrt/ndarray/broadcast.rs
Normal file
@ -0,0 +1,80 @@
|
||||
use inkwell::values::IntValue;
|
||||
|
||||
use crate::codegen::{
|
||||
expr::infer_and_call_function,
|
||||
irrt::get_usize_dependent_function_name,
|
||||
types::{ndarray::ShapeEntryType, ProxyType},
|
||||
values::{
|
||||
ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAccessor,
|
||||
TypedArrayLikeMutator,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_broadcast_to`.
|
||||
///
|
||||
/// Attempts to broadcast `src_ndarray` to the new shape defined by `dst_ndarray`.
|
||||
///
|
||||
/// `dst_ndarray` must meet the following preconditions:
|
||||
///
|
||||
/// - `dst_ndarray.ndims` must be initialized and matching the length of `dst_ndarray.shape`.
|
||||
/// - `dst_ndarray.shape` must be initialized and contains the target broadcast shape.
|
||||
/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values.
|
||||
pub fn call_nac3_ndarray_broadcast_to<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
src_ndarray: NDArrayValue<'ctx>,
|
||||
dst_ndarray: NDArrayValue<'ctx>,
|
||||
) {
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_to");
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[src_ndarray.as_abi_value(ctx).into(), dst_ndarray.as_abi_value(ctx).into()],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_broadcast_shapes`.
|
||||
///
|
||||
/// Attempts to calculate the resultant shape from broadcasting all shapes in `shape_entries`,
|
||||
/// writing the result to `dst_shape`.
|
||||
pub fn call_nac3_ndarray_broadcast_shapes<'ctx, G, Shape>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
num_shape_entries: IntValue<'ctx>,
|
||||
shape_entries: ArraySliceValue<'ctx>,
|
||||
dst_ndims: IntValue<'ctx>,
|
||||
dst_shape: &Shape,
|
||||
) where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
|
||||
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
|
||||
{
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
assert_eq!(num_shape_entries.get_type(), llvm_usize);
|
||||
assert!(ShapeEntryType::is_representable(
|
||||
shape_entries.base_ptr(ctx, generator).get_type(),
|
||||
llvm_usize,
|
||||
)
|
||||
.is_ok());
|
||||
assert_eq!(dst_ndims.get_type(), llvm_usize);
|
||||
assert_eq!(dst_shape.element_type(ctx, generator), llvm_usize.into());
|
||||
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_broadcast_shapes");
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[
|
||||
num_shape_entries.into(),
|
||||
shape_entries.base_ptr(ctx, generator).into(),
|
||||
dst_ndims.into(),
|
||||
dst_shape.base_ptr(ctx, generator).into(),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
34
nac3core/src/codegen/irrt/ndarray/indexing.rs
Normal file
34
nac3core/src/codegen/irrt/ndarray/indexing.rs
Normal file
@ -0,0 +1,34 @@
|
||||
use crate::codegen::{
|
||||
expr::infer_and_call_function,
|
||||
irrt::get_usize_dependent_function_name,
|
||||
values::{ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ProxyValue},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_index`.
|
||||
///
|
||||
/// Performs [basic indexing](https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
|
||||
/// on `src_ndarray` using `indices`, writing the result to `dst_ndarray`, corresponding to the
|
||||
/// operation `dst_ndarray = src_ndarray[indices]`.
|
||||
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
indices: ArraySliceValue<'ctx>,
|
||||
src_ndarray: NDArrayValue<'ctx>,
|
||||
dst_ndarray: NDArrayValue<'ctx>,
|
||||
) {
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_index");
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[
|
||||
indices.size(ctx, generator).into(),
|
||||
indices.base_ptr(ctx, generator).into(),
|
||||
src_ndarray.as_abi_value(ctx).into(),
|
||||
dst_ndarray.as_abi_value(ctx).into(),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
81
nac3core/src/codegen/irrt/ndarray/iter.rs
Normal file
81
nac3core/src/codegen/irrt/ndarray/iter.rs
Normal file
@ -0,0 +1,81 @@
|
||||
use inkwell::{
|
||||
types::BasicTypeEnum,
|
||||
values::{BasicValueEnum, IntValue},
|
||||
AddressSpace,
|
||||
};
|
||||
|
||||
use crate::codegen::{
|
||||
expr::{create_and_call_function, infer_and_call_function},
|
||||
irrt::get_usize_dependent_function_name,
|
||||
types::ProxyType,
|
||||
values::{
|
||||
ndarray::{NDArrayValue, NDIterValue},
|
||||
ProxyValue, TypedArrayLikeAccessor,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// Generates a call to `__nac3_nditer_initialize`.
|
||||
///
|
||||
/// Initializes the `iter` object.
|
||||
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
iter: NDIterValue<'ctx>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) {
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(indices.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_initialize");
|
||||
|
||||
create_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[
|
||||
(iter.get_type().as_abi_type().into(), iter.as_abi_value(ctx).into()),
|
||||
(ndarray.get_type().as_abi_type().into(), ndarray.as_abi_value(ctx).into()),
|
||||
(llvm_pusize.into(), indices.base_ptr(ctx, generator).into()),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_nditer_initialize_has_element`.
|
||||
///
|
||||
/// Returns an `i1` value indicating whether there are elements left to traverse for the `iter`
|
||||
/// object.
|
||||
pub fn call_nac3_nditer_has_element<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
iter: NDIterValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_has_element");
|
||||
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
Some(ctx.ctx.bool_type().into()),
|
||||
&[iter.as_abi_value(ctx).into()],
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_nditer_next`.
|
||||
///
|
||||
/// Moves `iter` to point to the next element.
|
||||
pub fn call_nac3_nditer_next<'ctx>(ctx: &CodeGenContext<'ctx, '_>, iter: NDIterValue<'ctx>) {
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_nditer_next");
|
||||
|
||||
infer_and_call_function(ctx, &name, None, &[iter.as_abi_value(ctx).into()], None, None);
|
||||
}
|
65
nac3core/src/codegen/irrt/ndarray/matmul.rs
Normal file
65
nac3core/src/codegen/irrt/ndarray/matmul.rs
Normal file
@ -0,0 +1,65 @@
|
||||
use inkwell::{types::BasicTypeEnum, values::IntValue};
|
||||
|
||||
use crate::codegen::{
|
||||
expr::infer_and_call_function, irrt::get_usize_dependent_function_name,
|
||||
values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_matmul_calculate_shapes`.
|
||||
///
|
||||
/// Calculates the broadcasted shapes for `a`, `b`, and the `ndarray` holding the final values of
|
||||
/// `a @ b`.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
final_ndims: IntValue<'ctx>,
|
||||
new_a_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
new_b_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
dst_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) {
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(a_shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(b_shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(new_a_shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(new_b_shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
assert_eq!(
|
||||
BasicTypeEnum::try_from(dst_shape.element_type(ctx, generator)).unwrap(),
|
||||
llvm_usize.into()
|
||||
);
|
||||
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_matmul_calculate_shapes");
|
||||
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[
|
||||
a_shape.size(ctx, generator).into(),
|
||||
a_shape.base_ptr(ctx, generator).into(),
|
||||
b_shape.size(ctx, generator).into(),
|
||||
b_shape.base_ptr(ctx, generator).into(),
|
||||
final_ndims.into(),
|
||||
new_a_shape.base_ptr(ctx, generator).into(),
|
||||
new_b_shape.base_ptr(ctx, generator).into(),
|
||||
dst_shape.base_ptr(ctx, generator).into(),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
17
nac3core/src/codegen/irrt/ndarray/mod.rs
Normal file
17
nac3core/src/codegen/irrt/ndarray/mod.rs
Normal file
@ -0,0 +1,17 @@
|
||||
pub use array::*;
|
||||
pub use basic::*;
|
||||
pub use broadcast::*;
|
||||
pub use indexing::*;
|
||||
pub use iter::*;
|
||||
pub use matmul::*;
|
||||
pub use reshape::*;
|
||||
pub use transpose::*;
|
||||
|
||||
mod array;
|
||||
mod basic;
|
||||
mod broadcast;
|
||||
mod indexing;
|
||||
mod iter;
|
||||
mod matmul;
|
||||
mod reshape;
|
||||
mod transpose;
|
39
nac3core/src/codegen/irrt/ndarray/reshape.rs
Normal file
39
nac3core/src/codegen/irrt/ndarray/reshape.rs
Normal file
@ -0,0 +1,39 @@
|
||||
use inkwell::values::IntValue;
|
||||
|
||||
use crate::codegen::{
|
||||
expr::infer_and_call_function,
|
||||
irrt::get_usize_dependent_function_name,
|
||||
values::{ArrayLikeValue, ArraySliceValue},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_reshape_resolve_and_check_new_shape`.
|
||||
///
|
||||
/// Resolves unknown dimensions in `new_shape` for `numpy.reshape(<ndarray>, new_shape)`, raising an
|
||||
/// assertion if multiple dimensions are unknown (`-1`).
|
||||
pub fn call_nac3_ndarray_reshape_resolve_and_check_new_shape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
size: IntValue<'ctx>,
|
||||
new_ndims: IntValue<'ctx>,
|
||||
new_shape: ArraySliceValue<'ctx>,
|
||||
) {
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
assert_eq!(size.get_type(), llvm_usize);
|
||||
assert_eq!(new_ndims.get_type(), llvm_usize);
|
||||
assert_eq!(new_shape.element_type(ctx, generator), llvm_usize.into());
|
||||
|
||||
let name = get_usize_dependent_function_name(
|
||||
ctx,
|
||||
"__nac3_ndarray_reshape_resolve_and_check_new_shape",
|
||||
);
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[size.into(), new_ndims.into(), new_shape.base_ptr(ctx, generator).into()],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
48
nac3core/src/codegen/irrt/ndarray/transpose.rs
Normal file
48
nac3core/src/codegen/irrt/ndarray/transpose.rs
Normal file
@ -0,0 +1,48 @@
|
||||
use inkwell::{values::IntValue, AddressSpace};
|
||||
|
||||
use crate::codegen::{
|
||||
expr::infer_and_call_function,
|
||||
irrt::get_usize_dependent_function_name,
|
||||
values::{ndarray::NDArrayValue, ProxyValue, TypedArrayLikeAccessor},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_transpose`.
|
||||
///
|
||||
/// Creates a transpose view of `src_ndarray` and writes the result to `dst_ndarray`.
|
||||
///
|
||||
/// `dst_ndarray` must fulfill the following preconditions:
|
||||
///
|
||||
/// - `dst_ndarray.ndims` must be initialized and must be equal to `src_ndarray.ndims`.
|
||||
/// - `dst_ndarray.shape` must be allocated and may contain uninitialized values.
|
||||
/// - `dst_ndarray.strides` must be allocated and may contain uninitialized values.
|
||||
pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
src_ndarray: NDArrayValue<'ctx>,
|
||||
dst_ndarray: NDArrayValue<'ctx>,
|
||||
axes: Option<&impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>>,
|
||||
) {
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
assert!(axes.is_none_or(|axes| axes.size(ctx, generator).get_type() == llvm_usize));
|
||||
assert!(axes.is_none_or(|axes| axes.element_type(ctx, generator) == llvm_usize.into()));
|
||||
|
||||
let name = get_usize_dependent_function_name(ctx, "__nac3_ndarray_transpose");
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
&name,
|
||||
None,
|
||||
&[
|
||||
src_ndarray.as_abi_value(ctx).into(),
|
||||
dst_ndarray.as_abi_value(ctx).into(),
|
||||
axes.map_or(llvm_usize.const_zero(), |axes| axes.size(ctx, generator)).into(),
|
||||
axes.map_or(llvm_usize.ptr_type(AddressSpace::default()).const_null(), |axes| {
|
||||
axes.base_ptr(ctx, generator)
|
||||
})
|
||||
.into(),
|
||||
],
|
||||
None,
|
||||
None,
|
||||
);
|
||||
}
|
56
nac3core/src/codegen/irrt/range.rs
Normal file
56
nac3core/src/codegen/irrt/range.rs
Normal file
@ -0,0 +1,56 @@
|
||||
use inkwell::{
|
||||
values::{BasicValueEnum, CallSiteValue, IntValue},
|
||||
IntPredicate,
|
||||
};
|
||||
use itertools::Either;
|
||||
|
||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||
|
||||
/// Invokes the `__nac3_range_slice_len` in IRRT.
|
||||
///
|
||||
/// - `start`: The `i32` start value for the slice.
|
||||
/// - `end`: The `i32` end value for the slice.
|
||||
/// - `step`: The `i32` step value for the slice.
|
||||
///
|
||||
/// Returns an `i32` value of the length of the slice.
|
||||
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
start: IntValue<'ctx>,
|
||||
end: IntValue<'ctx>,
|
||||
step: IntValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
const SYMBOL: &str = "__nac3_range_slice_len";
|
||||
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
|
||||
assert_eq!(start.get_type(), llvm_i32);
|
||||
assert_eq!(end.get_type(), llvm_i32);
|
||||
assert_eq!(step.get_type(), llvm_i32);
|
||||
|
||||
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
||||
let fn_t = llvm_i32.fn_type(&[llvm_i32.into(), llvm_i32.into(), llvm_i32.into()], false);
|
||||
ctx.module.add_function(SYMBOL, fn_t, None)
|
||||
});
|
||||
|
||||
// assert step != 0, throw exception if not
|
||||
let not_zero = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
|
||||
.unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
not_zero,
|
||||
"0:ValueError",
|
||||
"step must not be zero",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
ctx.builder
|
||||
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
39
nac3core/src/codegen/irrt/slice.rs
Normal file
39
nac3core/src/codegen/irrt/slice.rs
Normal file
@ -0,0 +1,39 @@
|
||||
use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue};
|
||||
use itertools::Either;
|
||||
|
||||
use nac3parser::ast::Expr;
|
||||
|
||||
use crate::{
|
||||
codegen::{CodeGenContext, CodeGenerator},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
|
||||
/// this function allows index out of range, since python
|
||||
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
|
||||
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
|
||||
i: &Expr<Option<Type>>,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
length: IntValue<'ctx>,
|
||||
) -> Result<Option<IntValue<'ctx>>, String> {
|
||||
const SYMBOL: &str = "__nac3_slice_index_bound";
|
||||
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
|
||||
let i32_t = ctx.ctx.i32_type();
|
||||
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
|
||||
ctx.module.add_function(SYMBOL, fn_t, None)
|
||||
});
|
||||
|
||||
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
|
||||
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
Ok(Some(
|
||||
ctx.builder
|
||||
.build_call(func, &[i.into(), length.into()], "bounded_ind")
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap(),
|
||||
))
|
||||
}
|
45
nac3core/src/codegen/irrt/string.rs
Normal file
45
nac3core/src/codegen/irrt/string.rs
Normal file
@ -0,0 +1,45 @@
|
||||
use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue};
|
||||
use itertools::Either;
|
||||
|
||||
use super::get_usize_dependent_function_name;
|
||||
use crate::codegen::CodeGenContext;
|
||||
|
||||
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
|
||||
pub fn call_string_eq<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
str1_ptr: PointerValue<'ctx>,
|
||||
str1_len: IntValue<'ctx>,
|
||||
str2_ptr: PointerValue<'ctx>,
|
||||
str2_len: IntValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
|
||||
let func_name = get_usize_dependent_function_name(ctx, "nac3_str_eq");
|
||||
|
||||
let func = ctx.module.get_function(&func_name).unwrap_or_else(|| {
|
||||
ctx.module.add_function(
|
||||
&func_name,
|
||||
llvm_i1.fn_type(
|
||||
&[
|
||||
str1_ptr.get_type().into(),
|
||||
str1_len.get_type().into(),
|
||||
str2_ptr.get_type().into(),
|
||||
str2_len.get_type().into(),
|
||||
],
|
||||
false,
|
||||
),
|
||||
None,
|
||||
)
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
func,
|
||||
&[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()],
|
||||
"str_eq_call",
|
||||
)
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
intrinsics::Intrinsic,
|
||||
types::{AnyTypeEnum::IntType, FloatType},
|
||||
types::AnyTypeEnum::IntType,
|
||||
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
};
|
||||
@ -9,34 +8,6 @@ use itertools::Either;
|
||||
|
||||
use super::CodeGenContext;
|
||||
|
||||
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
|
||||
/// functions.
|
||||
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
|
||||
// Standard LLVM floating-point types
|
||||
if ft == ctx.f16_type() {
|
||||
return "f16";
|
||||
}
|
||||
if ft == ctx.f32_type() {
|
||||
return "f32";
|
||||
}
|
||||
if ft == ctx.f64_type() {
|
||||
return "f64";
|
||||
}
|
||||
if ft == ctx.f128_type() {
|
||||
return "f128";
|
||||
}
|
||||
|
||||
// Non-standard floating-point types
|
||||
if ft == ctx.x86_f80_type() {
|
||||
return "f80";
|
||||
}
|
||||
if ft == ctx.ppc_f128_type() {
|
||||
return "ppcf128";
|
||||
}
|
||||
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
|
||||
/// intrinsic.
|
||||
pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
|
||||
@ -54,7 +25,7 @@ pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue
|
||||
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
|
||||
/// Invokes the [`llvm.va_end`](https://llvm.org/docs/LangRef.html#llvm-va-end-intrinsic)
|
||||
/// intrinsic.
|
||||
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
|
||||
const FN_NAME: &str = "llvm.va_end";
|
||||
@ -201,6 +172,49 @@ pub fn call_memcpy_generic<'ctx>(
|
||||
call_memcpy(ctx, dest, src, len, is_volatile);
|
||||
}
|
||||
|
||||
/// Invokes the `llvm.memcpy` intrinsic.
|
||||
///
|
||||
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
|
||||
/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`.
|
||||
/// Moreover, `len` now refers to the number of elements to copy (rather than number of bytes to
|
||||
/// copy).
|
||||
pub fn call_memcpy_generic_array<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dest: PointerValue<'ctx>,
|
||||
src: PointerValue<'ctx>,
|
||||
len: IntValue<'ctx>,
|
||||
is_volatile: IntValue<'ctx>,
|
||||
) {
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
let llvm_sizeof_expr_t = llvm_i8.size_of().get_type();
|
||||
|
||||
let dest_elem_t = dest.get_type().get_element_type();
|
||||
let src_elem_t = src.get_type().get_element_type();
|
||||
|
||||
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
|
||||
dest
|
||||
} else {
|
||||
ctx.builder
|
||||
.build_bit_cast(dest, llvm_p0i8, "")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
};
|
||||
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
|
||||
src
|
||||
} else {
|
||||
ctx.builder
|
||||
.build_bit_cast(src, llvm_p0i8, "")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
let len = ctx.builder.build_int_z_extend_or_bit_cast(len, llvm_sizeof_expr_t, "").unwrap();
|
||||
let len = ctx.builder.build_int_mul(len, src_elem_t.size_of().unwrap(), "").unwrap();
|
||||
|
||||
call_memcpy(ctx, dest, src, len, is_volatile);
|
||||
}
|
||||
|
||||
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
|
||||
///
|
||||
/// Arguments:
|
||||
@ -343,3 +357,25 @@ pub fn call_float_powi<'ctx>(
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Invokes the [`llvm.ctpop`](https://llvm.org/docs/LangRef.html#llvm-ctpop-intrinsic) intrinsic.
|
||||
pub fn call_int_ctpop<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
src: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> IntValue<'ctx> {
|
||||
const FN_NAME: &str = "llvm.ctpop";
|
||||
|
||||
let llvm_src_t = src.get_type();
|
||||
|
||||
let intrinsic_fn = Intrinsic::find(FN_NAME)
|
||||
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_src_t.into()]))
|
||||
.unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_call(intrinsic_fn, &[src.into()], name.unwrap_or_default())
|
||||
.map(CallSiteValue::try_as_basic_value)
|
||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
||||
.map(Either::unwrap_left)
|
||||
.unwrap()
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
use std::{
|
||||
cell::OnceCell,
|
||||
collections::{HashMap, HashSet},
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
@ -19,7 +20,7 @@ use inkwell::{
|
||||
module::Module,
|
||||
passes::PassBuilderOptions,
|
||||
targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple},
|
||||
types::{AnyType, BasicType, BasicTypeEnum},
|
||||
types::{AnyType, BasicType, BasicTypeEnum, IntType},
|
||||
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
|
||||
AddressSpace, IntPredicate, OptimizationLevel,
|
||||
};
|
||||
@ -30,18 +31,21 @@ use nac3parser::ast::{Location, Stmt, StrRef};
|
||||
|
||||
use crate::{
|
||||
symbol_resolver::{StaticValue, SymbolResolver},
|
||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
|
||||
toplevel::{
|
||||
helper::{extract_ndims, PrimDef},
|
||||
numpy::unpack_ndarray_var_tys,
|
||||
TopLevelContext, TopLevelDef,
|
||||
},
|
||||
typecheck::{
|
||||
type_inferencer::{CodeLocation, PrimitiveStore},
|
||||
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
|
||||
},
|
||||
};
|
||||
use classes::{ListType, NDArrayType, ProxyType, RangeType};
|
||||
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
||||
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
||||
use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType, TupleType};
|
||||
|
||||
pub mod builtin_fns;
|
||||
pub mod classes;
|
||||
pub mod concrete_type;
|
||||
pub mod expr;
|
||||
pub mod extern_fns;
|
||||
@ -50,6 +54,8 @@ pub mod irrt;
|
||||
pub mod llvm_intrinsics;
|
||||
pub mod numpy;
|
||||
pub mod stmt;
|
||||
pub mod types;
|
||||
pub mod values;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
@ -221,14 +227,33 @@ pub struct CodeGenContext<'ctx, 'a> {
|
||||
|
||||
/// The current source location.
|
||||
pub current_loc: Location,
|
||||
|
||||
/// The cached type of `size_t`.
|
||||
llvm_usize: OnceCell<IntType<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
impl<'ctx> CodeGenContext<'ctx, '_> {
|
||||
/// Whether the [current basic block][Builder::get_insert_block] referenced by `builder`
|
||||
/// contains a [terminator statement][BasicBlock::get_terminator].
|
||||
pub fn is_terminated(&self) -> bool {
|
||||
self.builder.get_insert_block().and_then(BasicBlock::get_terminator).is_some()
|
||||
}
|
||||
|
||||
/// Returns a [`IntType`] representing `size_t` for the compilation target as specified by
|
||||
/// [`self.registry`][WorkerRegistry].
|
||||
pub fn get_size_type(&self) -> IntType<'ctx> {
|
||||
*self.llvm_usize.get_or_init(|| {
|
||||
self.ctx.ptr_sized_int_type(
|
||||
&self
|
||||
.registry
|
||||
.llvm_options
|
||||
.create_target_machine()
|
||||
.map(|tm| tm.get_target_data())
|
||||
.unwrap(),
|
||||
None,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type Fp = Box<dyn Fn(&Module) + Send + Sync>;
|
||||
@ -476,6 +501,38 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| {
|
||||
let ty_enum = unifier.get_ty(ty);
|
||||
let result = match &*ty_enum {
|
||||
TModule {module_id, attributes} => {
|
||||
let top_level_defs = top_level.definitions.read();
|
||||
let definition = top_level_defs.get(module_id.0).unwrap();
|
||||
let TopLevelDef::Module { name, attributes: attribute_fields, .. } = &*definition.read() else {
|
||||
unreachable!()
|
||||
};
|
||||
let ty: BasicTypeEnum<'_> = if let Some(t) = module.get_struct_type(&name.to_string()) {
|
||||
t.ptr_type(AddressSpace::default()).into()
|
||||
} else {
|
||||
let struct_type = ctx.opaque_struct_type(&name.to_string());
|
||||
type_cache.insert(
|
||||
unifier.get_representative(ty),
|
||||
struct_type.ptr_type(AddressSpace::default()).into(),
|
||||
);
|
||||
let module_fields: Vec<BasicTypeEnum<'_>> = attribute_fields.iter()
|
||||
.map(|f| {
|
||||
get_llvm_type(
|
||||
ctx,
|
||||
module,
|
||||
generator,
|
||||
unifier,
|
||||
top_level,
|
||||
type_cache,
|
||||
attributes[&f.0].0,
|
||||
)
|
||||
})
|
||||
.collect_vec();
|
||||
struct_type.set_body(&module_fields, false);
|
||||
struct_type.ptr_type(AddressSpace::default()).into()
|
||||
};
|
||||
return ty;
|
||||
},
|
||||
TObj { obj_id, fields, .. } => {
|
||||
// check to avoid treating non-class primitives as classes
|
||||
if PrimDef::contains_id(*obj_id) {
|
||||
@ -505,16 +562,17 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
*params.iter().next().unwrap().1,
|
||||
);
|
||||
|
||||
ListType::new(generator, ctx, element_type).as_base_type().into()
|
||||
ListType::new_with_generator(generator, ctx, element_type).as_abi_type().into()
|
||||
}
|
||||
|
||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(unifier, ty);
|
||||
let ndims = extract_ndims(unifier, ndims);
|
||||
let element_type = get_llvm_type(
|
||||
ctx, module, generator, unifier, top_level, type_cache, dtype,
|
||||
);
|
||||
|
||||
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
|
||||
NDArrayType::new_with_generator(generator, ctx, element_type, ndims).as_abi_type().into()
|
||||
}
|
||||
|
||||
_ => unreachable!(
|
||||
@ -568,7 +626,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
|
||||
})
|
||||
.collect_vec();
|
||||
ctx.struct_type(&fields, false).into()
|
||||
TupleType::new_with_generator(generator, ctx, &fields).as_abi_type().into()
|
||||
}
|
||||
TVirtual { .. } => unimplemented!(),
|
||||
_ => unreachable!("{}", ty_enum.get_type_name()),
|
||||
@ -742,7 +800,7 @@ pub fn gen_func_impl<
|
||||
Some(t) => t.as_basic_type_enum(),
|
||||
}
|
||||
}),
|
||||
(primitives.range, RangeType::new(context).as_base_type().into()),
|
||||
(primitives.range, RangeType::new_with_generator(generator, context).as_abi_type().into()),
|
||||
(primitives.exception, {
|
||||
let name = "Exception";
|
||||
if let Some(t) = module.get_struct_type(name) {
|
||||
@ -981,8 +1039,20 @@ pub fn gen_func_impl<
|
||||
need_sret: has_sret,
|
||||
current_loc: Location::default(),
|
||||
debug_info: (dibuilder, compile_unit, func_scope.as_debug_info_scope()),
|
||||
llvm_usize: OnceCell::default(),
|
||||
};
|
||||
|
||||
let target_llvm_usize = context.ptr_sized_int_type(
|
||||
®istry.llvm_options.create_target_machine().map(|tm| tm.get_target_data()).unwrap(),
|
||||
None,
|
||||
);
|
||||
let generator_llvm_usize = generator.get_size_type(context);
|
||||
assert_eq!(
|
||||
generator_llvm_usize,
|
||||
target_llvm_usize,
|
||||
"CodeGenerator (size_t = {generator_llvm_usize}) is not compatible with CodeGen Target (size_t = {target_llvm_usize})",
|
||||
);
|
||||
|
||||
let loc = code_gen_context.debug_info.0.create_debug_location(
|
||||
context,
|
||||
row as u32,
|
||||
@ -1118,3 +1188,106 @@ fn gen_in_range_check<'ctx>(
|
||||
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
|
||||
format!("__{}_va_count", &arg_name).into()
|
||||
}
|
||||
|
||||
/// Returns the alignment of the type.
|
||||
///
|
||||
/// This is necessary as `get_alignment` is not implemented as part of [`BasicType`].
|
||||
pub fn get_type_alignment<'ctx>(ty: impl Into<BasicTypeEnum<'ctx>>) -> IntValue<'ctx> {
|
||||
match ty.into() {
|
||||
BasicTypeEnum::ArrayType(ty) => ty.get_alignment(),
|
||||
BasicTypeEnum::FloatType(ty) => ty.get_alignment(),
|
||||
BasicTypeEnum::IntType(ty) => ty.get_alignment(),
|
||||
BasicTypeEnum::PointerType(ty) => ty.get_alignment(),
|
||||
BasicTypeEnum::StructType(ty) => ty.get_alignment(),
|
||||
BasicTypeEnum::VectorType(ty) => ty.get_alignment(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Inserts an `alloca` instruction with allocation `size` given in bytes and the alignment of the
|
||||
/// given type.
|
||||
///
|
||||
/// The returned [`PointerValue`] will have a type of `i8*`, a size of at least `size`, and will be
|
||||
/// aligned with the alignment of `align_ty`.
|
||||
pub fn type_aligned_alloca<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
align_ty: impl Into<BasicTypeEnum<'ctx>>,
|
||||
size: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
/// Round `val` up to its modulo `power_of_two`.
|
||||
fn round_up<'ctx>(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
val: IntValue<'ctx>,
|
||||
power_of_two: IntValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
debug_assert_eq!(
|
||||
val.get_type().get_bit_width(),
|
||||
power_of_two.get_type().get_bit_width(),
|
||||
"`val` ({}) and `power_of_two` ({}) must be the same type",
|
||||
val.get_type(),
|
||||
power_of_two.get_type(),
|
||||
);
|
||||
|
||||
let llvm_val_t = val.get_type();
|
||||
|
||||
let max_rem =
|
||||
ctx.builder.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "").unwrap();
|
||||
ctx.builder
|
||||
.build_and(
|
||||
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
|
||||
ctx.builder.build_not(max_rem, "").unwrap(),
|
||||
"",
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let align_ty = align_ty.into();
|
||||
|
||||
let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "").unwrap();
|
||||
|
||||
debug_assert_eq!(
|
||||
size.get_type().get_bit_width(),
|
||||
llvm_usize.get_bit_width(),
|
||||
"Expected size_t ({}) for parameter `size` of `aligned_alloca`, got {}",
|
||||
llvm_usize,
|
||||
size.get_type(),
|
||||
);
|
||||
|
||||
let alignment = get_type_alignment(align_ty);
|
||||
let alignment = ctx.builder.build_int_truncate_or_bit_cast(alignment, llvm_usize, "").unwrap();
|
||||
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let alignment_bitcount = llvm_intrinsics::call_int_ctpop(ctx, alignment, None);
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
alignment_bitcount,
|
||||
alignment_bitcount.get_type().const_int(1, false),
|
||||
"",
|
||||
)
|
||||
.unwrap(),
|
||||
"0:AssertionError",
|
||||
"Expected power-of-two alignment for aligned_alloca, got {0}",
|
||||
[Some(alignment), None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
let buffer_size = round_up(ctx, size, alignment);
|
||||
let aligned_slices = ctx.builder.build_int_unsigned_div(buffer_size, alignment, "").unwrap();
|
||||
|
||||
// Just to be absolutely sure, alloca in [i8 x alignment] slices
|
||||
let buffer = ctx.builder.build_array_alloca(align_ty, aligned_slices, "").unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_bit_cast(buffer, llvm_pi8, name.unwrap_or_default())
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,7 @@
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
basic_block::BasicBlock,
|
||||
builder::Builder,
|
||||
types::{BasicType, BasicTypeEnum},
|
||||
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
|
||||
IntPredicate,
|
||||
@ -12,11 +13,15 @@ use nac3parser::ast::{
|
||||
};
|
||||
|
||||
use super::{
|
||||
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
|
||||
expr::{destructure_range, gen_binop_expr},
|
||||
gen_in_range_check,
|
||||
irrt::{handle_slice_indices, list_slice_assignment},
|
||||
macros::codegen_unreachable,
|
||||
types::{ndarray::NDArrayType, RangeType},
|
||||
values::{
|
||||
ndarray::{RustNDIndex, ScalarOrNDArray},
|
||||
ArrayLikeIndexer, ArraySliceValue, ListValue, ProxyValue,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
use crate::{
|
||||
@ -302,7 +307,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// Handle list item assignment
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let target_item_ty = iter_type_vars(list_params).next().unwrap().ty;
|
||||
|
||||
let target = generator
|
||||
@ -310,7 +315,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, target_ty)?
|
||||
.into_pointer_value();
|
||||
let target = ListValue::from_ptr_val(target, llvm_usize, None);
|
||||
let target = ListValue::from_pointer_value(target, llvm_usize, None);
|
||||
|
||||
if let ExprKind::Slice { .. } = &key.node {
|
||||
// Handle assigning to a slice
|
||||
@ -331,7 +336,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||
|
||||
let value =
|
||||
value.to_basic_value_enum(ctx, generator, value_ty)?.into_pointer_value();
|
||||
let value = ListValue::from_ptr_val(value, llvm_usize, None);
|
||||
let value = ListValue::from_pointer_value(value, llvm_usize, None);
|
||||
|
||||
let target_item_ty = ctx.get_llvm_type(generator, target_item_ty);
|
||||
let Some(src_ind) = handle_slice_indices(
|
||||
@ -363,10 +368,8 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, key_ty)?
|
||||
.into_int_value();
|
||||
let index = ctx
|
||||
.builder
|
||||
.build_int_s_extend(index, generator.get_size_type(ctx.ctx), "sext")
|
||||
.unwrap();
|
||||
let index =
|
||||
ctx.builder.build_int_s_extend(index, ctx.get_size_type(), "sext").unwrap();
|
||||
|
||||
// handle negative index
|
||||
let is_negative = ctx
|
||||
@ -374,7 +377,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
index,
|
||||
generator.get_size_type(ctx.ctx).const_zero(),
|
||||
ctx.get_size_type().const_zero(),
|
||||
"is_neg",
|
||||
)
|
||||
.unwrap();
|
||||
@ -411,7 +414,51 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// Handle NDArray item assignment
|
||||
todo!("ndarray subscript assignment is not yet implemented");
|
||||
// Process target
|
||||
let target = generator
|
||||
.gen_expr(ctx, target)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, target_ty)?;
|
||||
|
||||
// Process key
|
||||
let key = RustNDIndex::from_subscript_expr(generator, ctx, key)?;
|
||||
|
||||
// Process value
|
||||
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
|
||||
|
||||
// Reference code:
|
||||
// ```python
|
||||
// target = target[key]
|
||||
// value = np.asarray(value)
|
||||
//
|
||||
// shape = np.broadcast_shape((target, value))
|
||||
//
|
||||
// target = np.broadcast_to(target, shape)
|
||||
// value = np.broadcast_to(value, shape)
|
||||
//
|
||||
// # ...and finally copy 1-1 from value to target.
|
||||
// ```
|
||||
|
||||
let target = NDArrayType::from_unifier_type(generator, ctx, target_ty)
|
||||
.map_pointer_value(target.into_pointer_value(), None);
|
||||
let target = target.index(generator, ctx, &key);
|
||||
|
||||
let value = ScalarOrNDArray::from_value(generator, ctx, (value_ty, value))
|
||||
.to_ndarray(generator, ctx);
|
||||
|
||||
let broadcast_ndims =
|
||||
[target.get_type().ndims(), value.get_type().ndims()].into_iter().max().unwrap();
|
||||
let broadcast_result = NDArrayType::new(
|
||||
ctx,
|
||||
value.get_type().element_type(),
|
||||
broadcast_ndims,
|
||||
)
|
||||
.broadcast(generator, ctx, &[target, value]);
|
||||
|
||||
let target = broadcast_result.ndarrays[0];
|
||||
let value = broadcast_result.ndarrays[1];
|
||||
|
||||
target.copy_data_from(ctx, value);
|
||||
}
|
||||
_ => {
|
||||
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
|
||||
@ -435,7 +482,7 @@ pub fn gen_for<G: CodeGenerator>(
|
||||
let var_assignment = ctx.var_assignment.clone();
|
||||
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let size_t = generator.get_size_type(ctx.ctx);
|
||||
let size_t = ctx.get_size_type();
|
||||
let zero = int32.const_zero();
|
||||
let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
|
||||
let body_bb = ctx.ctx.append_basic_block(current, "for.body");
|
||||
@ -463,7 +510,8 @@ pub fn gen_for<G: CodeGenerator>(
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
|
||||
let iter_val =
|
||||
RangeType::new(ctx).map_pointer_value(iter_val.into_pointer_value(), Some("range"));
|
||||
// Internal variable for loop; Cannot be assigned
|
||||
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
|
||||
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
|
||||
@ -615,11 +663,25 @@ pub fn gen_for<G: CodeGenerator>(
|
||||
#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)]
|
||||
pub struct BreakContinueHooks<'ctx> {
|
||||
/// The [exit block][`BasicBlock`] to branch to when `break`-ing out of a loop.
|
||||
pub exit_bb: BasicBlock<'ctx>,
|
||||
exit_bb: BasicBlock<'ctx>,
|
||||
|
||||
/// The [latch basic block][`BasicBlock`] to branch to for `continue`-ing to the next iteration
|
||||
/// of the loop.
|
||||
pub latch_bb: BasicBlock<'ctx>,
|
||||
latch_bb: BasicBlock<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx> BreakContinueHooks<'ctx> {
|
||||
/// Creates a [`br` instruction][Builder::build_unconditional_branch] to the exit
|
||||
/// [`BasicBlock`], as if by calling `break`.
|
||||
pub fn build_break_branch(&self, builder: &Builder<'ctx>) {
|
||||
builder.build_unconditional_branch(self.exit_bb).unwrap();
|
||||
}
|
||||
|
||||
/// Creates a [`br` instruction][Builder::build_unconditional_branch] to the latch
|
||||
/// [`BasicBlock`], as if by calling `continue`.
|
||||
pub fn build_continue_branch(&self, builder: &Builder<'ctx>) {
|
||||
builder.build_unconditional_branch(self.latch_bb).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates a C-style `for` construct using lambdas, similar to the following C code:
|
||||
|
@ -16,8 +16,8 @@ use nac3parser::{
|
||||
use parking_lot::RwLock;
|
||||
|
||||
use super::{
|
||||
classes::{ListType, NDArrayType, ProxyType, RangeType},
|
||||
concrete_type::ConcreteTypeStore,
|
||||
types::{ndarray::NDArrayType, ListType, ProxyType, RangeType},
|
||||
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator,
|
||||
DefaultCodeGenerator, WithCall, WorkerRegistry,
|
||||
};
|
||||
@ -36,7 +36,6 @@ use crate::{
|
||||
struct Resolver {
|
||||
id_to_type: HashMap<StrRef, Type>,
|
||||
id_to_def: RwLock<HashMap<StrRef, DefinitionId>>,
|
||||
class_names: HashMap<StrRef, Type>,
|
||||
}
|
||||
|
||||
impl Resolver {
|
||||
@ -98,19 +97,18 @@ fn test_primitives() {
|
||||
"};
|
||||
let statements = parse_program(source, FileName::default()).unwrap();
|
||||
|
||||
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
|
||||
let context = inkwell::context::Context::create();
|
||||
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
|
||||
let mut unifier = composer.unifier.clone();
|
||||
let primitives = composer.primitives_ty;
|
||||
let top_level = Arc::new(composer.make_top_level_context());
|
||||
unifier.top_level = Some(top_level.clone());
|
||||
|
||||
let resolver = Arc::new(Resolver {
|
||||
id_to_type: HashMap::new(),
|
||||
id_to_def: RwLock::new(HashMap::new()),
|
||||
class_names: HashMap::default(),
|
||||
}) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
let resolver =
|
||||
Arc::new(Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) })
|
||||
as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
|
||||
let threads = vec![DefaultCodeGenerator::new("test".into(), context.i64_type()).into()];
|
||||
let signature = FunSignature {
|
||||
args: vec![
|
||||
FuncArg {
|
||||
@ -263,7 +261,8 @@ fn test_simple_call() {
|
||||
"};
|
||||
let statements_2 = parse_program(source_2, FileName::default()).unwrap();
|
||||
|
||||
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
|
||||
let context = inkwell::context::Context::create();
|
||||
let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
|
||||
let mut unifier = composer.unifier.clone();
|
||||
let primitives = composer.primitives_ty;
|
||||
let top_level = Arc::new(composer.make_top_level_context());
|
||||
@ -298,11 +297,7 @@ fn test_simple_call() {
|
||||
loc: None,
|
||||
})));
|
||||
|
||||
let resolver = Resolver {
|
||||
id_to_type: HashMap::new(),
|
||||
id_to_def: RwLock::new(HashMap::new()),
|
||||
class_names: HashMap::default(),
|
||||
};
|
||||
let resolver = Resolver { id_to_type: HashMap::new(), id_to_def: RwLock::new(HashMap::new()) };
|
||||
resolver.add_id_def("foo".into(), DefinitionId(foo_id));
|
||||
let resolver = Arc::new(resolver) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
@ -314,7 +309,7 @@ fn test_simple_call() {
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
|
||||
let threads = vec![DefaultCodeGenerator::new("test".into(), context.i64_type()).into()];
|
||||
let mut function_data = FunctionData {
|
||||
resolver: resolver.clone(),
|
||||
bound_variables: Vec::new(),
|
||||
@ -446,31 +441,34 @@ fn test_simple_call() {
|
||||
#[test]
|
||||
fn test_classes_list_type_new() {
|
||||
let ctx = inkwell::context::Context::create();
|
||||
let generator = DefaultCodeGenerator::new(String::new(), 64);
|
||||
let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type());
|
||||
|
||||
let llvm_i32 = ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(&ctx);
|
||||
|
||||
let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into());
|
||||
assert!(ListType::is_type(llvm_list.as_base_type(), llvm_usize).is_ok());
|
||||
let llvm_list = ListType::new_with_generator(&generator, &ctx, llvm_i32.into());
|
||||
assert!(ListType::is_representable(llvm_list.as_abi_type(), llvm_usize).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classes_range_type_new() {
|
||||
let ctx = inkwell::context::Context::create();
|
||||
let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type());
|
||||
|
||||
let llvm_range = RangeType::new(&ctx);
|
||||
assert!(RangeType::is_type(llvm_range.as_base_type()).is_ok());
|
||||
let llvm_usize = generator.get_size_type(&ctx);
|
||||
|
||||
let llvm_range = RangeType::new_with_generator(&generator, &ctx);
|
||||
assert!(RangeType::is_representable(llvm_range.as_abi_type(), llvm_usize).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_classes_ndarray_type_new() {
|
||||
let ctx = inkwell::context::Context::create();
|
||||
let generator = DefaultCodeGenerator::new(String::new(), 64);
|
||||
let generator = DefaultCodeGenerator::new(String::new(), ctx.i64_type());
|
||||
|
||||
let llvm_i32 = ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(&ctx);
|
||||
|
||||
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into());
|
||||
assert!(NDArrayType::is_type(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
|
||||
let llvm_ndarray = NDArrayType::new_with_generator(&generator, &ctx, llvm_i32.into(), 2);
|
||||
assert!(NDArrayType::is_representable(llvm_ndarray.as_abi_type(), llvm_usize).is_ok());
|
||||
}
|
||||
|
390
nac3core/src/codegen/types/list.rs
Normal file
390
nac3core/src/codegen/types/list.rs
Normal file
@ -0,0 +1,390 @@
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType},
|
||||
values::{IntValue, PointerValue, StructValue},
|
||||
AddressSpace, IntPredicate, OptimizationLevel,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
use super::ProxyType;
|
||||
use crate::{
|
||||
codegen::{
|
||||
types::structure::{
|
||||
check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields,
|
||||
StructProxyType,
|
||||
},
|
||||
values::ListValue,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
|
||||
};
|
||||
|
||||
/// Proxy type for a `list` type in LLVM.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct ListType<'ctx> {
|
||||
ty: PointerType<'ctx>,
|
||||
item: Option<BasicTypeEnum<'ctx>>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct ListStructFields<'ctx> {
|
||||
/// Array pointer to content.
|
||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||
pub items: StructField<'ctx, PointerValue<'ctx>>,
|
||||
|
||||
/// Number of items in the array.
|
||||
#[value_type(usize)]
|
||||
pub len: StructField<'ctx, IntValue<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> ListStructFields<'ctx> {
|
||||
#[must_use]
|
||||
pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
let mut counter = FieldIndexCounter::default();
|
||||
|
||||
ListStructFields {
|
||||
items: StructField::create(
|
||||
&mut counter,
|
||||
"items",
|
||||
item.ptr_type(AddressSpace::default()),
|
||||
),
|
||||
len: StructField::create(&mut counter, "len", llvm_usize),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ListType<'ctx> {
|
||||
/// Returns an instance of [`StructFields`] containing all field accessors for this type.
|
||||
#[must_use]
|
||||
fn fields(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> ListStructFields<'ctx> {
|
||||
ListStructFields::new_typed(item, llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an LLVM type corresponding to the expected structure of a `List`.
|
||||
#[must_use]
|
||||
fn llvm_type(
|
||||
ctx: &'ctx Context,
|
||||
element_type: Option<BasicTypeEnum<'ctx>>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> PointerType<'ctx> {
|
||||
let element_type = element_type.map_or(llvm_usize.into(), |ty| ty.as_basic_type_enum());
|
||||
|
||||
let field_tys =
|
||||
Self::fields(element_type, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||
|
||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||
}
|
||||
|
||||
fn new_impl(
|
||||
ctx: &'ctx Context,
|
||||
element_type: Option<BasicTypeEnum<'ctx>>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Self {
|
||||
let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize);
|
||||
|
||||
Self { ty: llvm_list, item: element_type, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an instance of [`ListType`].
|
||||
#[must_use]
|
||||
pub fn new(ctx: &CodeGenContext<'ctx, '_>, element_type: &impl BasicType<'ctx>) -> Self {
|
||||
Self::new_impl(ctx.ctx, Some(element_type.as_basic_type_enum()), ctx.get_size_type())
|
||||
}
|
||||
|
||||
/// Creates an instance of [`ListType`].
|
||||
#[must_use]
|
||||
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
element_type: BasicTypeEnum<'ctx>,
|
||||
) -> Self {
|
||||
Self::new_impl(ctx, Some(element_type.as_basic_type_enum()), generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
/// Creates an instance of [`ListType`] with an unknown element type.
|
||||
#[must_use]
|
||||
pub fn new_untyped(ctx: &CodeGenContext<'ctx, '_>) -> Self {
|
||||
Self::new_impl(ctx.ctx, None, ctx.get_size_type())
|
||||
}
|
||||
|
||||
/// Creates an instance of [`ListType`] with an unknown element type.
|
||||
#[must_use]
|
||||
pub fn new_untyped_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> Self {
|
||||
Self::new_impl(ctx, None, generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
/// Creates an [`ListType`] from a [unifier type][Type].
|
||||
#[must_use]
|
||||
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: Type,
|
||||
) -> Self {
|
||||
// Check unifier type and extract `item_type`
|
||||
let elem_type = match &*ctx.unifier.get_ty_immutable(ty) {
|
||||
TypeEnum::TObj { obj_id, params, .. }
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
iter_type_vars(params).next().unwrap().ty
|
||||
}
|
||||
|
||||
_ => panic!("Expected `list` type, but got {}", ctx.unifier.stringify(ty)),
|
||||
};
|
||||
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_elem_type = if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) {
|
||||
None
|
||||
} else {
|
||||
Some(ctx.get_llvm_type(generator, elem_type))
|
||||
};
|
||||
|
||||
Self::new_impl(ctx.ctx, llvm_elem_type, llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an [`ListType`] from a [`StructType`].
|
||||
#[must_use]
|
||||
pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an [`ListType`] from a [`PointerType`].
|
||||
#[must_use]
|
||||
pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok());
|
||||
|
||||
let ctx = ptr_ty.get_context();
|
||||
|
||||
// We are just searching for the index off a field - Slot an arbitrary element type in.
|
||||
let item_field_idx =
|
||||
Self::fields(ctx.i8_type().into(), llvm_usize).index_of_field(|f| f.items);
|
||||
let item = unsafe {
|
||||
ptr_ty
|
||||
.get_element_type()
|
||||
.into_struct_type()
|
||||
.get_field_type_at_index_unchecked(item_field_idx)
|
||||
.into_pointer_type()
|
||||
.get_element_type()
|
||||
};
|
||||
let item = BasicTypeEnum::try_from(item).unwrap_or_else(|()| {
|
||||
panic!(
|
||||
"Expected BasicTypeEnum for list element type, got {}",
|
||||
ptr_ty.get_element_type().print_to_string()
|
||||
)
|
||||
});
|
||||
|
||||
ListType { ty: ptr_ty, item: Some(item), llvm_usize }
|
||||
}
|
||||
|
||||
/// Returns the type of the `size` field of this `list` type.
|
||||
#[must_use]
|
||||
pub fn size_type(&self) -> IntType<'ctx> {
|
||||
self.llvm_usize
|
||||
}
|
||||
|
||||
/// Returns the element type of this `list` type.
|
||||
#[must_use]
|
||||
pub fn element_type(&self) -> Option<BasicTypeEnum<'ctx>> {
|
||||
self.item
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`ListValue`] as if by calling `alloca` on the base type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca`].
|
||||
#[must_use]
|
||||
pub fn alloca(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca(ctx, name),
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`ListValue`] as if by calling `alloca` on the base type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca_var`].
|
||||
#[must_use]
|
||||
pub fn alloca_var<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca_var(generator, ctx, name),
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Allocates a [`ListValue`] on the stack using `item` of this [`ListType`] instance.
|
||||
///
|
||||
/// The returned list will contain:
|
||||
///
|
||||
/// - `data`: Allocated with `len` number of elements.
|
||||
/// - `len`: Initialized to the value of `len` passed to this function.
|
||||
#[must_use]
|
||||
pub fn construct<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
len: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let len = ctx.builder.build_int_z_extend(len, self.llvm_usize, "").unwrap();
|
||||
|
||||
// Generate a runtime assertion if allocating a non-empty list with unknown element type
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None && self.item.is_none() {
|
||||
let len_eqz = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, len, self.llvm_usize.const_zero(), "")
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
len_eqz,
|
||||
"0:AssertionError",
|
||||
"Cannot allocate a non-empty list with unknown element type",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
let plist = self.alloca_var(generator, ctx, name);
|
||||
plist.store_size(ctx, len);
|
||||
|
||||
let item = self.item.unwrap_or(self.llvm_usize.into());
|
||||
plist.create_data(ctx, item, None);
|
||||
|
||||
plist
|
||||
}
|
||||
|
||||
/// Convenience function for creating a list with zero elements.
|
||||
///
|
||||
/// This function is preferred over [`ListType::construct`] if the length is known to always be
|
||||
/// 0, as this function avoids injecting an IR assertion for checking if a non-empty untyped
|
||||
/// list is being allocated.
|
||||
///
|
||||
/// The returned list will contain:
|
||||
///
|
||||
/// - `data`: Initialized to `(T*) 0`.
|
||||
/// - `len`: Initialized to `0`.
|
||||
#[must_use]
|
||||
pub fn construct_empty<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let plist = self.alloca_var(generator, ctx, name);
|
||||
|
||||
plist.store_size(ctx, self.llvm_usize.const_zero());
|
||||
plist.create_data(ctx, self.item.unwrap_or(self.llvm_usize.into()), None);
|
||||
|
||||
plist
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`ListValue`].
|
||||
#[must_use]
|
||||
pub fn map_struct_value<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
value: StructValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_struct_value(
|
||||
generator,
|
||||
ctx,
|
||||
value,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`ListValue`].
|
||||
#[must_use]
|
||||
pub fn map_pointer_value(
|
||||
&self,
|
||||
value: PointerValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(value, self.llvm_usize, name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyType<'ctx> for ListType<'ctx> {
|
||||
type ABI = PointerType<'ctx>;
|
||||
type Base = PointerType<'ctx>;
|
||||
type Value = ListValue<'ctx>;
|
||||
|
||||
fn is_representable(
|
||||
llvm_ty: impl BasicType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||
Self::has_same_repr(ty, llvm_usize)
|
||||
} else {
|
||||
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> {
|
||||
let ctx = ty.get_context();
|
||||
|
||||
let llvm_ty = ty.get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else {
|
||||
return Err(format!("Expected struct type for `list` type, got {llvm_ty}"));
|
||||
};
|
||||
|
||||
let fields = ListStructFields::new(ctx, llvm_usize);
|
||||
|
||||
check_struct_type_matches_fields(
|
||||
fields,
|
||||
llvm_ty,
|
||||
"list",
|
||||
&[(fields.items.name(), &|ty| {
|
||||
if ty.is_pointer_type() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Expected T* for `list.items`, got {ty}"))
|
||||
}
|
||||
})],
|
||||
)
|
||||
}
|
||||
|
||||
fn alloca_type(&self) -> impl BasicType<'ctx> {
|
||||
self.as_abi_type().get_element_type().into_struct_type()
|
||||
}
|
||||
|
||||
fn as_base_type(&self) -> Self::Base {
|
||||
self.ty
|
||||
}
|
||||
|
||||
fn as_abi_type(&self) -> Self::ABI {
|
||||
self.as_base_type()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> StructProxyType<'ctx> for ListType<'ctx> {
|
||||
type StructFields = ListStructFields<'ctx>;
|
||||
|
||||
fn get_fields(&self) -> Self::StructFields {
|
||||
Self::fields(self.item.unwrap_or(self.llvm_usize.into()), self.llvm_usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<ListType<'ctx>> for PointerType<'ctx> {
|
||||
fn from(value: ListType<'ctx>) -> Self {
|
||||
value.as_base_type()
|
||||
}
|
||||
}
|
129
nac3core/src/codegen/types/mod.rs
Normal file
129
nac3core/src/codegen/types/mod.rs
Normal file
@ -0,0 +1,129 @@
|
||||
//! This module contains abstraction over all intrinsic composite types of NAC3.
|
||||
//!
|
||||
//! # `raw_alloca` vs `alloca` vs `construct`
|
||||
//!
|
||||
//! There are three ways of creating a new object instance using the abstractions provided by this
|
||||
//! module.
|
||||
//!
|
||||
//! - `raw_alloca`: Allocates the object on the stack, returning an instance of
|
||||
//! [`impl BasicValue`][inkwell::values::BasicValue]. This is similar to a `malloc` expression in
|
||||
//! C++ but the object is allocated on the stack.
|
||||
//! - `alloca`: Similar to `raw_alloca`, but also wraps the allocated object with
|
||||
//! [`<Self as ProxyType<'ctx>>::Value`][ProxyValue], and returns the wrapped object. The returned
|
||||
//! object will not initialize any value or fields. This is similar to a type-safe `malloc`
|
||||
//! expression in C++ but the object is allocated on the stack.
|
||||
//! - `construct`: Similar to `alloca`, but performs some initialization on the value or fields of
|
||||
//! the returned object. This is similar to a `new` expression in C++ but the object is allocated
|
||||
//! on the stack.
|
||||
|
||||
use inkwell::{
|
||||
types::{BasicType, IntType},
|
||||
values::{IntValue, PointerValue},
|
||||
};
|
||||
|
||||
use super::{
|
||||
values::{ArraySliceValue, ProxyValue},
|
||||
{CodeGenContext, CodeGenerator},
|
||||
};
|
||||
pub use list::*;
|
||||
pub use range::*;
|
||||
pub use tuple::*;
|
||||
|
||||
mod list;
|
||||
pub mod ndarray;
|
||||
mod range;
|
||||
pub mod structure;
|
||||
mod tuple;
|
||||
pub mod utils;
|
||||
|
||||
/// A LLVM type that is used to represent a corresponding type in NAC3.
|
||||
pub trait ProxyType<'ctx>: Into<Self::Base> {
|
||||
/// The ABI type of which values of this type possess.
|
||||
type ABI: BasicType<'ctx>;
|
||||
|
||||
/// The LLVM type of which values of this type possess.
|
||||
type Base: BasicType<'ctx>;
|
||||
|
||||
/// The type of values represented by this type.
|
||||
type Value: ProxyValue<'ctx, Type = Self>;
|
||||
|
||||
/// Checks whether `llvm_ty` can be represented by this [`ProxyType`].
|
||||
fn is_representable(
|
||||
llvm_ty: impl BasicType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String>;
|
||||
|
||||
/// Checks whether the type represented by `ty` expresses the same type represented by this
|
||||
/// [`ProxyType`].
|
||||
fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String>;
|
||||
|
||||
/// Returns the type that should be used in `alloca` IR statements.
|
||||
fn alloca_type(&self) -> impl BasicType<'ctx>;
|
||||
|
||||
/// Creates a new value of this type by invoking `alloca` at the current builder location,
|
||||
/// returning a [`PointerValue`] instance representing the allocated value.
|
||||
fn raw_alloca(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
ctx.builder
|
||||
.build_alloca(self.alloca_type().as_basic_type_enum(), name.unwrap_or_default())
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Creates a new value of this type by invoking `alloca` at the beginning of the function,
|
||||
/// returning a [`PointerValue`] instance representing the allocated value.
|
||||
fn raw_alloca_var<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
generator.gen_var_alloc(ctx, self.alloca_type().as_basic_type_enum(), name).unwrap()
|
||||
}
|
||||
|
||||
/// Creates a new array value of this type by invoking `alloca` at the current builder location,
|
||||
/// returning an [`ArraySliceValue`] encapsulating the resulting array.
|
||||
fn array_alloca(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
size: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> ArraySliceValue<'ctx> {
|
||||
ArraySliceValue::from_ptr_val(
|
||||
ctx.builder
|
||||
.build_array_alloca(
|
||||
self.alloca_type().as_basic_type_enum(),
|
||||
size,
|
||||
name.unwrap_or_default(),
|
||||
)
|
||||
.unwrap(),
|
||||
size,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a new array value of this type by invoking `alloca` at the beginning of the
|
||||
/// function, returning an [`ArraySliceValue`] encapsulating the resulting array.
|
||||
fn array_alloca_var<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
size: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> ArraySliceValue<'ctx> {
|
||||
generator
|
||||
.gen_array_var_alloc(ctx, self.alloca_type().as_basic_type_enum(), size, name)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Returns the [base type][Self::Base] of this proxy.
|
||||
fn as_base_type(&self) -> Self::Base;
|
||||
|
||||
/// Returns this proxy as its ABI type, i.e. the expected type representation if a value of this
|
||||
/// [`ProxyType`] is being passed into or returned from a function.
|
||||
///
|
||||
/// See [`CodeGenContext::get_llvm_abi_type`].
|
||||
fn as_abi_type(&self) -> Self::ABI;
|
||||
}
|
240
nac3core/src/codegen/types/ndarray/array.rs
Normal file
240
nac3core/src/codegen/types/ndarray/array.rs
Normal file
@ -0,0 +1,240 @@
|
||||
use inkwell::{
|
||||
types::BasicTypeEnum,
|
||||
values::{BasicValueEnum, IntValue},
|
||||
AddressSpace,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt,
|
||||
stmt::gen_if_else_expr_callback,
|
||||
types::{ndarray::NDArrayType, ListType, ProxyType},
|
||||
values::{
|
||||
ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ListValue, ProxyValue,
|
||||
TypedArrayLikeAdapter, TypedArrayLikeMutator,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims},
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
|
||||
/// Get the expected `dtype` and `ndims` of the ndarray returned by `np_array(<list>)`.
|
||||
fn get_list_object_dtype_and_ndims<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
list_ty: Type,
|
||||
) -> (BasicTypeEnum<'ctx>, u64) {
|
||||
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list_ty);
|
||||
let ndims = arraylike_get_ndims(&mut ctx.unifier, list_ty);
|
||||
|
||||
(ctx.get_llvm_type(generator, dtype), ndims)
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayType<'ctx> {
|
||||
/// Implementation of `np_array(<list>, copy=True)`
|
||||
fn construct_numpy_array_from_list_copy_true_impl<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
(list_ty, list): (Type, ListValue<'ctx>),
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let (dtype, ndims_int) = get_list_object_dtype_and_ndims(generator, ctx, list_ty);
|
||||
assert!(self.ndims >= ndims_int);
|
||||
assert_eq!(dtype, self.dtype);
|
||||
|
||||
let list_value = list.as_i8_list(ctx);
|
||||
|
||||
// Validate `list` has a consistent shape.
|
||||
// Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`.
|
||||
// If `list` has a consistent shape, deduce the shape and write it to `shape`.
|
||||
let ndims = self.llvm_usize.const_int(ndims_int, false);
|
||||
let shape = ctx.builder.build_array_alloca(self.llvm_usize, ndims, "").unwrap();
|
||||
let shape = ArraySliceValue::from_ptr_val(shape, ndims, None);
|
||||
let shape = TypedArrayLikeAdapter::from(
|
||||
shape,
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
irrt::ndarray::call_nac3_ndarray_array_set_and_validate_list_shape(
|
||||
generator, ctx, list_value, ndims, &shape,
|
||||
);
|
||||
|
||||
let ndarray =
|
||||
Self::new(ctx, dtype, ndims_int).construct_uninitialized(generator, ctx, name);
|
||||
ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
|
||||
unsafe { ndarray.create_data(generator, ctx) };
|
||||
|
||||
// Copy all contents from the list.
|
||||
irrt::ndarray::call_nac3_ndarray_array_write_list_to_array(ctx, list_value, ndarray);
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Implementation of `np_array(<list>, copy=None)`
|
||||
fn construct_numpy_array_from_list_copy_none_impl<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
(list_ty, list): (Type, ListValue<'ctx>),
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
// np_array without copying is only possible `list` is not nested.
|
||||
//
|
||||
// If `list` is `list[T]`, we can create an ndarray with `data` set
|
||||
// to the array pointer of `list`.
|
||||
//
|
||||
// If `list` is `list[list[T]]` or worse, copy.
|
||||
|
||||
let (dtype, ndims) = get_list_object_dtype_and_ndims(generator, ctx, list_ty);
|
||||
if ndims == 1 {
|
||||
// `list` is not nested
|
||||
assert_eq!(ndims, 1);
|
||||
assert!(self.ndims >= ndims);
|
||||
assert_eq!(dtype, self.dtype);
|
||||
|
||||
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray = Self::new(ctx, dtype, 1).construct_uninitialized(generator, ctx, name);
|
||||
|
||||
// Set data
|
||||
let data = ctx
|
||||
.builder
|
||||
.build_pointer_cast(list.data().base_ptr(ctx, generator), llvm_pi8, "")
|
||||
.unwrap();
|
||||
ndarray.store_data(ctx, data);
|
||||
|
||||
// ndarray->shape[0] = list->len;
|
||||
let shape = ndarray.shape();
|
||||
let list_len = list.load_size(ctx, None);
|
||||
unsafe {
|
||||
shape.set_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), list_len);
|
||||
}
|
||||
|
||||
// Set strides, the `data` is contiguous
|
||||
ndarray.set_strides_contiguous(ctx);
|
||||
|
||||
ndarray
|
||||
} else {
|
||||
// `list` is nested, copy
|
||||
self.construct_numpy_array_from_list_copy_true_impl(
|
||||
generator,
|
||||
ctx,
|
||||
(list_ty, list),
|
||||
name,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of `np_array(<list>, copy=copy)`
|
||||
fn construct_numpy_array_list_impl<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
(list_ty, list): (Type, ListValue<'ctx>),
|
||||
copy: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
assert_eq!(copy.get_type(), ctx.ctx.bool_type());
|
||||
|
||||
let (dtype, ndims) = get_list_object_dtype_and_ndims(generator, ctx, list_ty);
|
||||
|
||||
let ndarray = gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_generator, _ctx| Ok(copy),
|
||||
|generator, ctx| {
|
||||
let ndarray = self.construct_numpy_array_from_list_copy_true_impl(
|
||||
generator,
|
||||
ctx,
|
||||
(list_ty, list),
|
||||
name,
|
||||
);
|
||||
Ok(Some(ndarray.as_abi_value(ctx)))
|
||||
},
|
||||
|generator, ctx| {
|
||||
let ndarray = self.construct_numpy_array_from_list_copy_none_impl(
|
||||
generator,
|
||||
ctx,
|
||||
(list_ty, list),
|
||||
name,
|
||||
);
|
||||
Ok(Some(ndarray.as_abi_value(ctx)))
|
||||
},
|
||||
)
|
||||
.unwrap()
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap();
|
||||
|
||||
NDArrayType::new(ctx, dtype, ndims).map_pointer_value(ndarray, None)
|
||||
}
|
||||
|
||||
/// Implementation of `np_array(<ndarray>, copy=copy)`.
|
||||
pub fn construct_numpy_array_ndarray_impl<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
copy: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
assert_eq!(ndarray.get_type().dtype, self.dtype);
|
||||
assert!(self.ndims >= ndarray.get_type().ndims);
|
||||
assert_eq!(copy.get_type(), ctx.ctx.bool_type());
|
||||
|
||||
let ndarray_val = gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_generator, _ctx| Ok(copy),
|
||||
|generator, ctx| {
|
||||
let ndarray = ndarray.make_copy(generator, ctx); // Force copy
|
||||
Ok(Some(ndarray.as_abi_value(ctx)))
|
||||
},
|
||||
|_generator, ctx| {
|
||||
// No need to copy. Return `ndarray` itself.
|
||||
Ok(Some(ndarray.as_abi_value(ctx)))
|
||||
},
|
||||
)
|
||||
.unwrap()
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap();
|
||||
|
||||
ndarray.get_type().map_pointer_value(ndarray_val, name)
|
||||
}
|
||||
|
||||
/// Create a new ndarray like
|
||||
/// [`np.array()`](https://numpy.org/doc/stable/reference/generated/numpy.array.html).
|
||||
///
|
||||
/// Note that the returned [`NDArrayValue`] may have fewer dimensions than is specified by this
|
||||
/// instance. Use [`NDArrayValue::atleast_nd`] on the returned value if an `ndarray` instance
|
||||
/// with the exact number of dimensions is needed.
|
||||
pub fn construct_numpy_array<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
(object_ty, object): (Type, BasicValueEnum<'ctx>),
|
||||
copy: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
match &*ctx.unifier.get_ty_immutable(object_ty) {
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
let list = ListType::from_unifier_type(generator, ctx, object_ty)
|
||||
.map_pointer_value(object.into_pointer_value(), None);
|
||||
self.construct_numpy_array_list_impl(generator, ctx, (object_ty, list), copy, name)
|
||||
}
|
||||
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
let ndarray = NDArrayType::from_unifier_type(generator, ctx, object_ty)
|
||||
.map_pointer_value(object.into_pointer_value(), None);
|
||||
self.construct_numpy_array_ndarray_impl(generator, ctx, ndarray, copy, name)
|
||||
}
|
||||
|
||||
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object_ty)), // Typechecker ensures this
|
||||
}
|
||||
}
|
||||
}
|
205
nac3core/src/codegen/types/ndarray/broadcast.rs
Normal file
205
nac3core/src/codegen/types/ndarray/broadcast.rs
Normal file
@ -0,0 +1,205 @@
|
||||
use inkwell::{
|
||||
context::{AsContextRef, Context},
|
||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType},
|
||||
values::{IntValue, PointerValue, StructValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
use crate::codegen::{
|
||||
types::{
|
||||
structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType},
|
||||
ProxyType,
|
||||
},
|
||||
values::ndarray::ShapeEntryValue,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct ShapeEntryType<'ctx> {
|
||||
ty: PointerType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct ShapeEntryStructFields<'ctx> {
|
||||
#[value_type(usize)]
|
||||
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> ShapeEntryType<'ctx> {
|
||||
/// Returns an instance of [`StructFields`] containing all field accessors for this type.
|
||||
#[must_use]
|
||||
fn fields(
|
||||
ctx: impl AsContextRef<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> ShapeEntryStructFields<'ctx> {
|
||||
ShapeEntryStructFields::new(ctx, llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an LLVM type corresponding to the expected structure of a `ShapeEntry`.
|
||||
#[must_use]
|
||||
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||
let field_tys =
|
||||
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||
|
||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||
}
|
||||
|
||||
fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self {
|
||||
let llvm_ty = Self::llvm_type(ctx, llvm_usize);
|
||||
|
||||
Self { ty: llvm_ty, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an instance of [`ShapeEntryType`].
|
||||
#[must_use]
|
||||
pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self {
|
||||
Self::new_impl(ctx.ctx, ctx.get_size_type())
|
||||
}
|
||||
|
||||
/// Creates an instance of [`ShapeEntryType`].
|
||||
#[must_use]
|
||||
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> Self {
|
||||
Self::new_impl(ctx, generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
/// Creates a [`ShapeEntryType`] from a [`StructType`] representing an `ShapeEntry`.
|
||||
#[must_use]
|
||||
pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates a [`ShapeEntryType`] from a [`PointerType`] representing an `ShapeEntry`.
|
||||
#[must_use]
|
||||
pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok());
|
||||
|
||||
Self { ty: ptr_ty, llvm_usize }
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type.
|
||||
#[must_use]
|
||||
pub fn alloca(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca(ctx, name),
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`ShapeEntryValue`] as if by calling `alloca` on the base type.
|
||||
#[must_use]
|
||||
pub fn alloca_var<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca_var(generator, ctx, name),
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`ShapeEntryValue`].
|
||||
#[must_use]
|
||||
pub fn map_struct_value<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
value: StructValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_struct_value(
|
||||
generator,
|
||||
ctx,
|
||||
value,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`ShapeEntryValue`].
|
||||
#[must_use]
|
||||
pub fn map_pointer_value(
|
||||
&self,
|
||||
value: PointerValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(value, self.llvm_usize, name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyType<'ctx> for ShapeEntryType<'ctx> {
|
||||
type ABI = PointerType<'ctx>;
|
||||
type Base = PointerType<'ctx>;
|
||||
type Value = ShapeEntryValue<'ctx>;
|
||||
|
||||
fn is_representable(
|
||||
llvm_ty: impl BasicType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||
Self::has_same_repr(ty, llvm_usize)
|
||||
} else {
|
||||
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> {
|
||||
let ctx = ty.get_context();
|
||||
|
||||
let llvm_ndarray_ty = ty.get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||
return Err(format!(
|
||||
"Expected struct type for `ShapeEntry` type, got {llvm_ndarray_ty}"
|
||||
));
|
||||
};
|
||||
|
||||
check_struct_type_matches_fields(
|
||||
Self::fields(ctx, llvm_usize),
|
||||
llvm_ndarray_ty,
|
||||
"NDArray",
|
||||
&[],
|
||||
)
|
||||
}
|
||||
|
||||
fn alloca_type(&self) -> impl BasicType<'ctx> {
|
||||
self.as_abi_type().get_element_type().into_struct_type()
|
||||
}
|
||||
|
||||
fn as_base_type(&self) -> Self::Base {
|
||||
self.ty
|
||||
}
|
||||
|
||||
fn as_abi_type(&self) -> Self::ABI {
|
||||
self.as_base_type()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> StructProxyType<'ctx> for ShapeEntryType<'ctx> {
|
||||
type StructFields = ShapeEntryStructFields<'ctx>;
|
||||
|
||||
fn get_fields(&self) -> Self::StructFields {
|
||||
Self::fields(self.ty.get_context(), self.llvm_usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<ShapeEntryType<'ctx>> for PointerType<'ctx> {
|
||||
fn from(value: ShapeEntryType<'ctx>) -> Self {
|
||||
value.as_base_type()
|
||||
}
|
||||
}
|
281
nac3core/src/codegen/types/ndarray/contiguous.rs
Normal file
281
nac3core/src/codegen/types/ndarray/contiguous.rs
Normal file
@ -0,0 +1,281 @@
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType},
|
||||
values::{IntValue, PointerValue, StructValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
types::{
|
||||
structure::{
|
||||
check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields,
|
||||
StructProxyType,
|
||||
},
|
||||
ProxyType,
|
||||
},
|
||||
values::ndarray::ContiguousNDArrayValue,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
toplevel::numpy::unpack_ndarray_var_tys,
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct ContiguousNDArrayType<'ctx> {
|
||||
ty: PointerType<'ctx>,
|
||||
item: BasicTypeEnum<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct ContiguousNDArrayStructFields<'ctx> {
|
||||
#[value_type(usize)]
|
||||
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> ContiguousNDArrayStructFields<'ctx> {
|
||||
#[must_use]
|
||||
pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
let mut counter = FieldIndexCounter::default();
|
||||
|
||||
ContiguousNDArrayStructFields {
|
||||
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
|
||||
shape: StructField::create(
|
||||
&mut counter,
|
||||
"shape",
|
||||
llvm_usize.ptr_type(AddressSpace::default()),
|
||||
),
|
||||
data: StructField::create(&mut counter, "data", item.ptr_type(AddressSpace::default())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ContiguousNDArrayType<'ctx> {
|
||||
/// Returns an instance of [`StructFields`] containing all field accessors for this type.
|
||||
#[must_use]
|
||||
fn fields(
|
||||
item: BasicTypeEnum<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> ContiguousNDArrayStructFields<'ctx> {
|
||||
ContiguousNDArrayStructFields::new_typed(item, llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||
#[must_use]
|
||||
fn llvm_type(
|
||||
ctx: &'ctx Context,
|
||||
item: BasicTypeEnum<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> PointerType<'ctx> {
|
||||
let field_tys =
|
||||
Self::fields(item, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||
|
||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||
}
|
||||
|
||||
fn new_impl(ctx: &'ctx Context, item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize);
|
||||
|
||||
Self { ty: llvm_cndarray, item, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an instance of [`ContiguousNDArrayType`].
|
||||
#[must_use]
|
||||
pub fn new(ctx: &CodeGenContext<'ctx, '_>, item: &impl BasicType<'ctx>) -> Self {
|
||||
Self::new_impl(ctx.ctx, item.as_basic_type_enum(), ctx.get_size_type())
|
||||
}
|
||||
|
||||
/// Creates an instance of [`ContiguousNDArrayType`].
|
||||
#[must_use]
|
||||
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
item: BasicTypeEnum<'ctx>,
|
||||
) -> Self {
|
||||
Self::new_impl(ctx, item, generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
/// Creates an [`ContiguousNDArrayType`] from a [unifier type][Type].
|
||||
#[must_use]
|
||||
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: Type,
|
||||
) -> Self {
|
||||
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
|
||||
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
|
||||
|
||||
Self::new_impl(ctx.ctx, llvm_dtype, ctx.get_size_type())
|
||||
}
|
||||
|
||||
/// Creates an [`ContiguousNDArrayType`] from a [`StructType`] representing an `NDArray`.
|
||||
#[must_use]
|
||||
pub fn from_struct_type(
|
||||
ty: StructType<'ctx>,
|
||||
item: BasicTypeEnum<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Self {
|
||||
Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), item, llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`.
|
||||
#[must_use]
|
||||
pub fn from_pointer_type(
|
||||
ptr_ty: PointerType<'ctx>,
|
||||
item: BasicTypeEnum<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Self {
|
||||
debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok());
|
||||
|
||||
Self { ty: ptr_ty, item, llvm_usize }
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base
|
||||
/// type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca`].
|
||||
#[must_use]
|
||||
pub fn alloca(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca(ctx, name),
|
||||
self.item,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base
|
||||
/// type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca_var`].
|
||||
#[must_use]
|
||||
pub fn alloca_var<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca_var(generator, ctx, name),
|
||||
self.item,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`ContiguousNDArrayValue`].
|
||||
#[must_use]
|
||||
pub fn map_struct_value<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
value: StructValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_struct_value(
|
||||
generator,
|
||||
ctx,
|
||||
value,
|
||||
self.item,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`ContiguousNDArrayValue`].
|
||||
#[must_use]
|
||||
pub fn map_pointer_value(
|
||||
&self,
|
||||
value: PointerValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
value,
|
||||
self.item,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> {
|
||||
type ABI = PointerType<'ctx>;
|
||||
type Base = PointerType<'ctx>;
|
||||
type Value = ContiguousNDArrayValue<'ctx>;
|
||||
|
||||
fn is_representable(
|
||||
llvm_ty: impl BasicType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||
Self::has_same_repr(ty, llvm_usize)
|
||||
} else {
|
||||
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> {
|
||||
let ctx = ty.get_context();
|
||||
|
||||
let llvm_ty = ty.get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else {
|
||||
return Err(format!(
|
||||
"Expected struct type for `ContiguousNDArray` type, got {llvm_ty}"
|
||||
));
|
||||
};
|
||||
|
||||
let fields = ContiguousNDArrayStructFields::new(ctx, llvm_usize);
|
||||
|
||||
check_struct_type_matches_fields(
|
||||
fields,
|
||||
llvm_ty,
|
||||
"ContiguousNDArray",
|
||||
&[(fields.data.name(), &|ty| {
|
||||
if ty.is_pointer_type() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Expected T* for `ContiguousNDArray.data`, got {ty}"))
|
||||
}
|
||||
})],
|
||||
)
|
||||
}
|
||||
|
||||
fn alloca_type(&self) -> impl BasicType<'ctx> {
|
||||
self.as_abi_type().get_element_type().into_struct_type()
|
||||
}
|
||||
|
||||
fn as_base_type(&self) -> Self::Base {
|
||||
self.ty
|
||||
}
|
||||
|
||||
fn as_abi_type(&self) -> Self::ABI {
|
||||
self.as_base_type()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> StructProxyType<'ctx> for ContiguousNDArrayType<'ctx> {
|
||||
type StructFields = ContiguousNDArrayStructFields<'ctx>;
|
||||
|
||||
fn get_fields(&self) -> Self::StructFields {
|
||||
Self::fields(self.item, self.llvm_usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<ContiguousNDArrayType<'ctx>> for PointerType<'ctx> {
|
||||
fn from(value: ContiguousNDArrayType<'ctx>) -> Self {
|
||||
value.as_base_type()
|
||||
}
|
||||
}
|
236
nac3core/src/codegen/types/ndarray/factory.rs
Normal file
236
nac3core/src/codegen/types/ndarray/factory.rs
Normal file
@ -0,0 +1,236 @@
|
||||
use inkwell::{
|
||||
values::{BasicValueEnum, IntValue},
|
||||
IntPredicate,
|
||||
};
|
||||
|
||||
use super::NDArrayType;
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt, types::ProxyType, values::TypedArrayLikeAccessor, CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
|
||||
/// Get the zero value in `np.zeros()` of a `dtype`.
|
||||
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
ctx.ctx.i32_type().const_zero().into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
ctx.ctx.i64_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "").into()
|
||||
} else {
|
||||
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the one value in `np.ones()` of a `dtype`.
|
||||
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
|
||||
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64]
|
||||
.iter()
|
||||
.any(|ty| ctx.unifier.unioned(dtype, *ty))
|
||||
{
|
||||
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
|
||||
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_float(1.0).into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_int(1, false).into()
|
||||
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "1").into()
|
||||
} else {
|
||||
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayType<'ctx> {
|
||||
/// Create an ndarray like
|
||||
/// [`np.empty`](https://numpy.org/doc/stable/reference/generated/numpy.empty.html).
|
||||
pub fn construct_numpy_empty<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let ndarray = self.construct_uninitialized(generator, ctx, name);
|
||||
|
||||
// Validate `shape`
|
||||
irrt::ndarray::call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, shape);
|
||||
|
||||
ndarray.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
|
||||
unsafe { ndarray.create_data(generator, ctx) };
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Create an ndarray like
|
||||
/// [`np.full`](https://numpy.org/doc/stable/reference/generated/numpy.full.html).
|
||||
pub fn construct_numpy_full<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
fill_value: BasicValueEnum<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let ndarray = self.construct_numpy_empty(generator, ctx, shape, name);
|
||||
ndarray.fill(generator, ctx, fill_value);
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Create an ndarray like
|
||||
/// [`np.zero`](https://numpy.org/doc/stable/reference/generated/numpy.zeros.html).
|
||||
pub fn construct_numpy_zeros<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
assert_eq!(
|
||||
ctx.get_llvm_type(generator, dtype),
|
||||
self.dtype,
|
||||
"Expected LLVM dtype={} but got {}",
|
||||
self.dtype.print_to_string(),
|
||||
ctx.get_llvm_type(generator, dtype).print_to_string(),
|
||||
);
|
||||
|
||||
let fill_value = ndarray_zero_value(generator, ctx, dtype);
|
||||
self.construct_numpy_full(generator, ctx, shape, fill_value, name)
|
||||
}
|
||||
|
||||
/// Create an ndarray like
|
||||
/// [`np.ones`](https://numpy.org/doc/stable/reference/generated/numpy.ones.html).
|
||||
pub fn construct_numpy_ones<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
assert_eq!(
|
||||
ctx.get_llvm_type(generator, dtype),
|
||||
self.dtype,
|
||||
"Expected LLVM dtype={} but got {}",
|
||||
self.dtype.print_to_string(),
|
||||
ctx.get_llvm_type(generator, dtype).print_to_string(),
|
||||
);
|
||||
|
||||
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
||||
self.construct_numpy_full(generator, ctx, shape, fill_value, name)
|
||||
}
|
||||
|
||||
/// Create an ndarray like
|
||||
/// [`np.eye`](https://numpy.org/doc/stable/reference/generated/numpy.eye.html).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn construct_numpy_eye<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
nrows: IntValue<'ctx>,
|
||||
ncols: IntValue<'ctx>,
|
||||
offset: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
assert_eq!(
|
||||
ctx.get_llvm_type(generator, dtype),
|
||||
self.dtype,
|
||||
"Expected LLVM dtype={} but got {}",
|
||||
self.dtype.print_to_string(),
|
||||
ctx.get_llvm_type(generator, dtype).print_to_string(),
|
||||
);
|
||||
assert_eq!(nrows.get_type(), self.llvm_usize);
|
||||
assert_eq!(ncols.get_type(), self.llvm_usize);
|
||||
assert_eq!(offset.get_type(), self.llvm_usize);
|
||||
|
||||
let ndzero = ndarray_zero_value(generator, ctx, dtype);
|
||||
let ndone = ndarray_one_value(generator, ctx, dtype);
|
||||
|
||||
let ndarray = self.construct_dyn_shape(generator, ctx, &[nrows, ncols], name);
|
||||
|
||||
// Create data and make the matrix like look np.eye()
|
||||
unsafe {
|
||||
ndarray.create_data(generator, ctx);
|
||||
}
|
||||
ndarray
|
||||
.foreach(generator, ctx, |generator, ctx, _, nditer| {
|
||||
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
|
||||
// and this loop would not execute.
|
||||
|
||||
let indices = nditer.get_indices();
|
||||
|
||||
let row_i = unsafe {
|
||||
indices.get_typed_unchecked(ctx, generator, &self.llvm_usize.const_zero(), None)
|
||||
};
|
||||
let col_i = unsafe {
|
||||
indices.get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&self.llvm_usize.const_int(1, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
let be_one = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
ctx.builder.build_int_add(row_i, offset, "").unwrap(),
|
||||
col_i,
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
let value = ctx.builder.build_select(be_one, ndone, ndzero, "value").unwrap();
|
||||
|
||||
let p = nditer.get_pointer(ctx);
|
||||
ctx.builder.build_store(p, value).unwrap();
|
||||
|
||||
Ok(())
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Create an ndarray like
|
||||
/// [`np.identity`](https://numpy.org/doc/stable/reference/generated/numpy.identity.html).
|
||||
pub fn construct_numpy_identity<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dtype: Type,
|
||||
size: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let offset = self.llvm_usize.const_zero();
|
||||
self.construct_numpy_eye(generator, ctx, dtype, size, size, offset, name)
|
||||
}
|
||||
}
|
233
nac3core/src/codegen/types/ndarray/indexing.rs
Normal file
233
nac3core/src/codegen/types/ndarray/indexing.rs
Normal file
@ -0,0 +1,233 @@
|
||||
use inkwell::{
|
||||
context::{AsContextRef, Context},
|
||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType},
|
||||
values::{IntValue, PointerValue, StructValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
use crate::codegen::{
|
||||
types::{
|
||||
structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType},
|
||||
ProxyType,
|
||||
},
|
||||
values::{
|
||||
ndarray::{NDIndexValue, RustNDIndex},
|
||||
ArrayLikeIndexer, ArraySliceValue,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct NDIndexType<'ctx> {
|
||||
ty: PointerType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct NDIndexStructFields<'ctx> {
|
||||
#[value_type(i8_type())]
|
||||
pub type_: StructField<'ctx, IntValue<'ctx>>,
|
||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> NDIndexType<'ctx> {
|
||||
#[must_use]
|
||||
fn fields(
|
||||
ctx: impl AsContextRef<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> NDIndexStructFields<'ctx> {
|
||||
NDIndexStructFields::new(ctx, llvm_usize)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||
let field_tys =
|
||||
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||
|
||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||
}
|
||||
|
||||
fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self {
|
||||
let llvm_ndindex = Self::llvm_type(ctx, llvm_usize);
|
||||
|
||||
Self { ty: llvm_ndindex, llvm_usize }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self {
|
||||
Self::new_impl(ctx.ctx, ctx.get_size_type())
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> Self {
|
||||
Self::new_impl(ctx, generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok());
|
||||
|
||||
Self { ty: ptr_ty, llvm_usize }
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`NDIndexValue`] as if by calling `alloca` on the base type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca`].
|
||||
#[must_use]
|
||||
pub fn alloca(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca(ctx, name),
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
/// Allocates an instance of [`NDIndexValue`] as if by calling `alloca` on the base type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca_var`].
|
||||
#[must_use]
|
||||
pub fn alloca_var<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca_var(generator, ctx, name),
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Serialize a list of [`RustNDIndex`] as a newly allocated LLVM array of [`NDIndexValue`].
|
||||
#[must_use]
|
||||
pub fn construct_ndindices<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
in_ndindices: &[RustNDIndex<'ctx>],
|
||||
) -> ArraySliceValue<'ctx> {
|
||||
// Allocate the LLVM ndindices.
|
||||
let num_ndindices = self.llvm_usize.const_int(in_ndindices.len() as u64, false);
|
||||
let ndindices = self.array_alloca_var(generator, ctx, num_ndindices, None);
|
||||
|
||||
// Initialize all of them.
|
||||
for (i, in_ndindex) in in_ndindices.iter().enumerate() {
|
||||
let pndindex = unsafe {
|
||||
ndindices.ptr_offset_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&ctx.ctx.i64_type().const_int(u64::try_from(i).unwrap(), false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
in_ndindex.write_to_ndindex(
|
||||
generator,
|
||||
ctx,
|
||||
NDIndexValue::from_pointer_value(pndindex, self.llvm_usize, None),
|
||||
);
|
||||
}
|
||||
|
||||
ndindices
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn map_struct_value<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
value: StructValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_struct_value(
|
||||
generator,
|
||||
ctx,
|
||||
value,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn map_pointer_value(
|
||||
&self,
|
||||
value: PointerValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(value, self.llvm_usize, name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> {
|
||||
type ABI = PointerType<'ctx>;
|
||||
type Base = PointerType<'ctx>;
|
||||
type Value = NDIndexValue<'ctx>;
|
||||
|
||||
fn is_representable(
|
||||
llvm_ty: impl BasicType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||
Self::has_same_repr(ty, llvm_usize)
|
||||
} else {
|
||||
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> {
|
||||
let ctx = ty.get_context();
|
||||
|
||||
let llvm_ty = ty.get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else {
|
||||
return Err(format!(
|
||||
"Expected struct type for `ContiguousNDArray` type, got {llvm_ty}"
|
||||
));
|
||||
};
|
||||
|
||||
let fields = NDIndexStructFields::new(ctx, llvm_usize);
|
||||
|
||||
check_struct_type_matches_fields(fields, llvm_ty, "NDIndex", &[])
|
||||
}
|
||||
|
||||
fn alloca_type(&self) -> impl BasicType<'ctx> {
|
||||
self.as_abi_type().get_element_type().into_struct_type()
|
||||
}
|
||||
|
||||
fn as_base_type(&self) -> Self::Base {
|
||||
self.ty
|
||||
}
|
||||
|
||||
fn as_abi_type(&self) -> Self::ABI {
|
||||
self.as_base_type()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> StructProxyType<'ctx> for NDIndexType<'ctx> {
|
||||
type StructFields = NDIndexStructFields<'ctx>;
|
||||
|
||||
fn get_fields(&self) -> Self::StructFields {
|
||||
Self::fields(self.ty.get_context(), self.llvm_usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<NDIndexType<'ctx>> for PointerType<'ctx> {
|
||||
fn from(value: NDIndexType<'ctx>) -> Self {
|
||||
value.as_base_type()
|
||||
}
|
||||
}
|
183
nac3core/src/codegen/types/ndarray/map.rs
Normal file
183
nac3core/src/codegen/types/ndarray/map.rs
Normal file
@ -0,0 +1,183 @@
|
||||
use inkwell::{types::BasicTypeEnum, values::BasicValueEnum};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::{
|
||||
stmt::gen_for_callback,
|
||||
types::{
|
||||
ndarray::{NDArrayType, NDIterType},
|
||||
ProxyType,
|
||||
},
|
||||
values::{
|
||||
ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray},
|
||||
ArrayLikeValue, ProxyValue,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
impl<'ctx> NDArrayType<'ctx> {
|
||||
/// Generate LLVM IR to broadcast `ndarray`s together, and starmap through them with `mapping`
|
||||
/// elementwise.
|
||||
///
|
||||
/// `mapping` is an LLVM IR generator. The input of `mapping` is the list of elements when
|
||||
/// iterating through the input `ndarrays` after broadcasting. The output of `mapping` is the
|
||||
/// result of the elementwise operation.
|
||||
///
|
||||
/// `out` specifies whether the result should be a new ndarray or to be written an existing
|
||||
/// ndarray.
|
||||
pub fn broadcast_starmap<'a, G, MappingFn>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ndarrays: &[NDArrayValue<'ctx>],
|
||||
out: NDArrayOut<'ctx>,
|
||||
mapping: MappingFn,
|
||||
) -> Result<<Self as ProxyType<'ctx>>::Value, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
MappingFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
&[BasicValueEnum<'ctx>],
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
// Broadcast inputs
|
||||
let broadcast_result = self.broadcast(generator, ctx, ndarrays);
|
||||
|
||||
let out_ndarray = match out {
|
||||
NDArrayOut::NewNDArray { dtype } => {
|
||||
// Create a new ndarray based on the broadcast shape.
|
||||
let result_ndarray = NDArrayType::new(ctx, dtype, broadcast_result.ndims)
|
||||
.construct_uninitialized(generator, ctx, None);
|
||||
result_ndarray.copy_shape_from_array(
|
||||
generator,
|
||||
ctx,
|
||||
broadcast_result.shape.base_ptr(ctx, generator),
|
||||
);
|
||||
unsafe {
|
||||
result_ndarray.create_data(generator, ctx);
|
||||
}
|
||||
result_ndarray
|
||||
}
|
||||
|
||||
NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => {
|
||||
// Use an existing ndarray.
|
||||
|
||||
// Check that its shape is compatible with the broadcast shape.
|
||||
result_ndarray.assert_can_be_written_by_out(generator, ctx, broadcast_result.shape);
|
||||
result_ndarray
|
||||
}
|
||||
};
|
||||
|
||||
// Map element-wise and store results into `mapped_ndarray`.
|
||||
let nditer = NDIterType::new(ctx).construct(generator, ctx, out_ndarray);
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
Some("broadcast_starmap"),
|
||||
|generator, ctx| {
|
||||
// Create NDIters for all broadcasted input ndarrays.
|
||||
let other_nditers = broadcast_result
|
||||
.ndarrays
|
||||
.iter()
|
||||
.map(|ndarray| NDIterType::new(ctx).construct(generator, ctx, *ndarray))
|
||||
.collect_vec();
|
||||
Ok((nditer, other_nditers))
|
||||
},
|
||||
|_, ctx, (out_nditer, _in_nditers)| {
|
||||
// We can simply use `out_nditer`'s `has_element()`.
|
||||
// `in_nditers`' `has_element()`s should return the same value.
|
||||
Ok(out_nditer.has_element(ctx))
|
||||
},
|
||||
|generator, ctx, _hooks, (out_nditer, in_nditers)| {
|
||||
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
|
||||
// and write to `out_ndarray`.
|
||||
let in_scalars =
|
||||
in_nditers.iter().map(|nditer| nditer.get_scalar(ctx)).collect_vec();
|
||||
|
||||
let result = mapping(generator, ctx, &in_scalars)?;
|
||||
|
||||
let p = out_nditer.get_pointer(ctx);
|
||||
ctx.builder.build_store(p, result).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, ctx, (out_nditer, in_nditers)| {
|
||||
// Advance all iterators
|
||||
out_nditer.next(ctx);
|
||||
in_nditers.iter().for_each(|nditer| nditer.next(ctx));
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(out_ndarray)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
/// Starmap through a list of inputs using `mapping`, where an input could be an ndarray, a
|
||||
/// scalar.
|
||||
///
|
||||
/// This function is very helpful when implementing NumPy functions that takes on either scalars
|
||||
/// or ndarrays or a mix of them as their inputs and produces either an ndarray with broadcast,
|
||||
/// or a scalar if all its inputs are all scalars.
|
||||
///
|
||||
/// For example ,this function can be used to implement `np.add`, which has the following
|
||||
/// behaviors:
|
||||
///
|
||||
/// - `np.add(3, 4) = 7` # (scalar, scalar) -> scalar
|
||||
/// - `np.add(3, np.array([4, 5, 6]))` # (scalar, ndarray) -> ndarray; the first `scalar` is
|
||||
/// converted into an ndarray and broadcasted.
|
||||
/// - `np.add(np.array([[1], [2], [3]]), np.array([[4, 5, 6]]))` # (ndarray, ndarray) ->
|
||||
/// ndarray; there is broadcasting.
|
||||
///
|
||||
/// ## Details:
|
||||
///
|
||||
/// If `inputs` are all [`ScalarOrNDArray::Scalar`], the output will be a
|
||||
/// [`ScalarOrNDArray::Scalar`] with type `ret_dtype`.
|
||||
///
|
||||
/// Otherwise (if there are any [`ScalarOrNDArray::NDArray`] in `inputs`), all inputs will be
|
||||
/// 'as-ndarray'-ed into ndarrays, then all inputs (now all ndarrays) will be passed to
|
||||
/// [`NDArrayValue::broadcasting_starmap`] and **create** a new ndarray with dtype `ret_dtype`.
|
||||
pub fn broadcasting_starmap<'a, G, MappingFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
inputs: &[ScalarOrNDArray<'ctx>],
|
||||
ret_dtype: BasicTypeEnum<'ctx>,
|
||||
mapping: MappingFn,
|
||||
) -> Result<ScalarOrNDArray<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
MappingFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
&[BasicValueEnum<'ctx>],
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
// Check if all inputs are Scalars
|
||||
let all_scalars: Option<Vec<_>> =
|
||||
inputs.iter().map(BasicValueEnum::<'ctx>::try_from).try_collect().ok();
|
||||
|
||||
if let Some(scalars) = all_scalars {
|
||||
let scalars = scalars.iter().copied().collect_vec();
|
||||
let value = mapping(generator, ctx, &scalars)?;
|
||||
|
||||
Ok(ScalarOrNDArray::Scalar(value))
|
||||
} else {
|
||||
// Promote all input to ndarrays and map through them.
|
||||
let inputs = inputs.iter().map(|input| input.to_ndarray(generator, ctx)).collect_vec();
|
||||
let ndarray = NDArrayType::new_broadcast(
|
||||
ctx,
|
||||
ret_dtype,
|
||||
&inputs.iter().map(NDArrayValue::get_type).collect_vec(),
|
||||
)
|
||||
.broadcast_starmap(
|
||||
generator,
|
||||
ctx,
|
||||
&inputs,
|
||||
NDArrayOut::NewNDArray { dtype: ret_dtype },
|
||||
mapping,
|
||||
)?;
|
||||
Ok(ScalarOrNDArray::NDArray(ndarray))
|
||||
}
|
||||
}
|
||||
}
|
510
nac3core/src/codegen/types/ndarray/mod.rs
Normal file
510
nac3core/src/codegen/types/ndarray/mod.rs
Normal file
@ -0,0 +1,510 @@
|
||||
use inkwell::{
|
||||
context::{AsContextRef, Context},
|
||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType},
|
||||
values::{BasicValue, IntValue, PointerValue, StructValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
use super::{
|
||||
structure::{check_struct_type_matches_fields, StructField, StructFields, StructProxyType},
|
||||
ProxyType,
|
||||
};
|
||||
use crate::{
|
||||
codegen::{
|
||||
values::{ndarray::NDArrayValue, TypedArrayLikeMutator},
|
||||
{CodeGenContext, CodeGenerator},
|
||||
},
|
||||
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
pub use broadcast::*;
|
||||
pub use contiguous::*;
|
||||
pub use indexing::*;
|
||||
pub use nditer::*;
|
||||
|
||||
mod array;
|
||||
mod broadcast;
|
||||
mod contiguous;
|
||||
pub mod factory;
|
||||
mod indexing;
|
||||
mod map;
|
||||
mod nditer;
|
||||
|
||||
/// Proxy type for a `ndarray` type in LLVM.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct NDArrayType<'ctx> {
|
||||
ty: PointerType<'ctx>,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
ndims: u64,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct NDArrayStructFields<'ctx> {
|
||||
/// The size of each `NDArray` element in bytes.
|
||||
#[value_type(usize)]
|
||||
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
|
||||
/// Number of dimensions in the array.
|
||||
#[value_type(usize)]
|
||||
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||
/// Pointer to an array containing the shape of the `NDArray`.
|
||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||
/// Pointer to an array indicating the number of bytes between each element at a dimension
|
||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||
pub strides: StructField<'ctx, PointerValue<'ctx>>,
|
||||
/// Pointer to an array containing the array data
|
||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayType<'ctx> {
|
||||
/// Returns an instance of [`StructFields`] containing all field accessors for this type.
|
||||
#[must_use]
|
||||
fn fields(
|
||||
ctx: impl AsContextRef<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> NDArrayStructFields<'ctx> {
|
||||
NDArrayStructFields::new(ctx, llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||
#[must_use]
|
||||
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||
let field_tys =
|
||||
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||
|
||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||
}
|
||||
|
||||
fn new_impl(
|
||||
ctx: &'ctx Context,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
ndims: u64,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Self {
|
||||
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
|
||||
|
||||
NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an instance of [`NDArrayType`].
|
||||
#[must_use]
|
||||
pub fn new(ctx: &CodeGenContext<'ctx, '_>, dtype: BasicTypeEnum<'ctx>, ndims: u64) -> Self {
|
||||
Self::new_impl(ctx.ctx, dtype, ndims, ctx.get_size_type())
|
||||
}
|
||||
|
||||
/// Creates an instance of [`NDArrayType`].
|
||||
#[must_use]
|
||||
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
ndims: u64,
|
||||
) -> Self {
|
||||
Self::new_impl(ctx, dtype, ndims, generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
/// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more
|
||||
/// `ndarray` operands.
|
||||
#[must_use]
|
||||
pub fn new_broadcast(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
inputs: &[NDArrayType<'ctx>],
|
||||
) -> Self {
|
||||
assert!(!inputs.is_empty());
|
||||
|
||||
Self::new_impl(
|
||||
ctx.ctx,
|
||||
dtype,
|
||||
inputs.iter().map(NDArrayType::ndims).max().unwrap(),
|
||||
ctx.get_size_type(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates an instance of [`NDArrayType`] as a result of a broadcast operation over one or more
|
||||
/// `ndarray` operands.
|
||||
#[must_use]
|
||||
pub fn new_broadcast_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
inputs: &[NDArrayType<'ctx>],
|
||||
) -> Self {
|
||||
assert!(!inputs.is_empty());
|
||||
|
||||
Self::new_impl(
|
||||
ctx,
|
||||
dtype,
|
||||
inputs.iter().map(NDArrayType::ndims).max().unwrap(),
|
||||
generator.get_size_type(ctx),
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates an instance of [`NDArrayType`] with `ndims` of 0.
|
||||
#[must_use]
|
||||
pub fn new_unsized(ctx: &CodeGenContext<'ctx, '_>, dtype: BasicTypeEnum<'ctx>) -> Self {
|
||||
Self::new_impl(ctx.ctx, dtype, 0, ctx.get_size_type())
|
||||
}
|
||||
|
||||
/// Creates an instance of [`NDArrayType`] with `ndims` of 0.
|
||||
#[must_use]
|
||||
pub fn new_unsized_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
) -> Self {
|
||||
Self::new_impl(ctx, dtype, 0, generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
/// Creates an [`NDArrayType`] from a [unifier type][Type].
|
||||
#[must_use]
|
||||
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: Type,
|
||||
) -> Self {
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
|
||||
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
|
||||
Self::new_impl(ctx.ctx, llvm_dtype, ndims, ctx.get_size_type())
|
||||
}
|
||||
|
||||
/// Creates an [`NDArrayType`] from a [`StructType`] representing an `NDArray`.
|
||||
#[must_use]
|
||||
pub fn from_struct_type(
|
||||
ty: StructType<'ctx>,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
ndims: u64,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Self {
|
||||
Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), dtype, ndims, llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
|
||||
#[must_use]
|
||||
pub fn from_pointer_type(
|
||||
ptr_ty: PointerType<'ctx>,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
ndims: u64,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Self {
|
||||
debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok());
|
||||
|
||||
NDArrayType { ty: ptr_ty, dtype, ndims, llvm_usize }
|
||||
}
|
||||
|
||||
/// Returns the type of the `size` field of this `ndarray` type.
|
||||
#[must_use]
|
||||
pub fn size_type(&self) -> IntType<'ctx> {
|
||||
self.llvm_usize
|
||||
}
|
||||
|
||||
/// Returns the element type of this `ndarray` type.
|
||||
#[must_use]
|
||||
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
/// Returns the number of dimensions of this `ndarray` type.
|
||||
#[must_use]
|
||||
pub fn ndims(&self) -> u64 {
|
||||
self.ndims
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`NDArrayValue`] as if by calling `alloca` on the base type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca`].
|
||||
#[must_use]
|
||||
pub fn alloca(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca(ctx, name),
|
||||
self.dtype,
|
||||
self.ndims,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`NDArrayValue`] as if by calling `alloca` on the base type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca_var`].
|
||||
#[must_use]
|
||||
pub fn alloca_var<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca_var(generator, ctx, name),
|
||||
self.dtype,
|
||||
self.ndims,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Allocates an [`NDArrayValue`] on the stack and initializes all fields as follows:
|
||||
///
|
||||
/// - `data`: uninitialized.
|
||||
/// - `itemsize`: set to the size of `self.dtype`.
|
||||
/// - `ndims`: set to the value of `ndims`.
|
||||
/// - `shape`: allocated on the stack with an array of length `ndims` with uninitialized values.
|
||||
/// - `strides`: allocated on the stack with an array of length `ndims` with uninitialized
|
||||
/// values.
|
||||
#[must_use]
|
||||
fn construct_impl<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndims: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let ndarray = self.alloca_var(generator, ctx, name);
|
||||
|
||||
let itemsize = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "")
|
||||
.unwrap();
|
||||
ndarray.store_itemsize(ctx, itemsize);
|
||||
|
||||
ndarray.store_ndims(ctx, ndims);
|
||||
|
||||
ndarray.create_shape(ctx, self.llvm_usize, ndims);
|
||||
ndarray.create_strides(ctx, self.llvm_usize, ndims);
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Allocate an [`NDArrayValue`] on the stack using `dtype` and `ndims` of this [`NDArrayType`]
|
||||
/// instance.
|
||||
///
|
||||
/// The returned ndarray's content will be:
|
||||
/// - `data`: uninitialized.
|
||||
/// - `itemsize`: set to the size of `dtype`.
|
||||
/// - `ndims`: set to the value of `self.ndims`.
|
||||
/// - `shape`: allocated on the stack with an array of length `ndims` with uninitialized values.
|
||||
/// - `strides`: allocated on the stack with an array of length `ndims` with uninitialized
|
||||
/// values.
|
||||
#[must_use]
|
||||
pub fn construct_uninitialized<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let ndims = self.llvm_usize.const_int(self.ndims, false);
|
||||
|
||||
self.construct_impl(generator, ctx, ndims, name)
|
||||
}
|
||||
|
||||
/// Convenience function. Allocate an [`NDArrayValue`] with a statically known shape.
|
||||
///
|
||||
/// The returned [`NDArrayValue`]'s `data` and `strides` are uninitialized.
|
||||
#[must_use]
|
||||
pub fn construct_const_shape<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
shape: &[u64],
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
assert_eq!(shape.len() as u64, self.ndims);
|
||||
|
||||
let ndarray = Self::new(ctx, self.dtype, shape.len() as u64)
|
||||
.construct_uninitialized(generator, ctx, name);
|
||||
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
// Write shape
|
||||
let ndarray_shape = ndarray.shape();
|
||||
for (i, dim) in shape.iter().enumerate() {
|
||||
let dim = llvm_usize.const_int(*dim, false);
|
||||
unsafe {
|
||||
ndarray_shape.set_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(i as u64, false),
|
||||
dim,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Convenience function. Allocate an [`NDArrayValue`] with a dynamically known shape.
|
||||
///
|
||||
/// The returned [`NDArrayValue`]'s `data` and `strides` are uninitialized.
|
||||
#[must_use]
|
||||
pub fn construct_dyn_shape<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
shape: &[IntValue<'ctx>],
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
assert_eq!(shape.len() as u64, self.ndims);
|
||||
|
||||
let ndarray = Self::new(ctx, self.dtype, shape.len() as u64)
|
||||
.construct_uninitialized(generator, ctx, name);
|
||||
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
// Write shape
|
||||
let ndarray_shape = ndarray.shape();
|
||||
for (i, dim) in shape.iter().enumerate() {
|
||||
assert_eq!(
|
||||
dim.get_type(),
|
||||
llvm_usize,
|
||||
"Expected {} but got {}",
|
||||
llvm_usize.print_to_string(),
|
||||
dim.get_type().print_to_string()
|
||||
);
|
||||
unsafe {
|
||||
ndarray_shape.set_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(i as u64, false),
|
||||
*dim,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Create an unsized ndarray to contain `value`.
|
||||
#[must_use]
|
||||
pub fn construct_unsized<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
value: &impl BasicValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> NDArrayValue<'ctx> {
|
||||
let value = value.as_basic_value_enum();
|
||||
|
||||
assert_eq!(value.get_type(), self.dtype);
|
||||
assert_eq!(self.ndims, 0);
|
||||
|
||||
// We have to put the value on the stack to get a data pointer.
|
||||
let data = ctx.builder.build_alloca(value.get_type(), "construct_unsized").unwrap();
|
||||
ctx.builder.build_store(data, value).unwrap();
|
||||
let data = ctx
|
||||
.builder
|
||||
.build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
|
||||
.unwrap();
|
||||
|
||||
let ndarray =
|
||||
Self::new_unsized(ctx, value.get_type()).construct_uninitialized(generator, ctx, name);
|
||||
ctx.builder.build_store(ndarray.ptr_to_data(ctx), data).unwrap();
|
||||
ndarray
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`NDArrayValue`].
|
||||
#[must_use]
|
||||
pub fn map_struct_value<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
value: StructValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_struct_value(
|
||||
generator,
|
||||
ctx,
|
||||
value,
|
||||
self.dtype,
|
||||
self.ndims,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`NDArrayValue`].
|
||||
#[must_use]
|
||||
pub fn map_pointer_value(
|
||||
&self,
|
||||
value: PointerValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
value,
|
||||
self.dtype,
|
||||
self.ndims,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
|
||||
type ABI = PointerType<'ctx>;
|
||||
type Base = PointerType<'ctx>;
|
||||
type Value = NDArrayValue<'ctx>;
|
||||
|
||||
fn is_representable(
|
||||
llvm_ty: impl BasicType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||
Self::has_same_repr(ty, llvm_usize)
|
||||
} else {
|
||||
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> {
|
||||
let ctx = ty.get_context();
|
||||
|
||||
let llvm_ndarray_ty = ty.get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
|
||||
};
|
||||
|
||||
check_struct_type_matches_fields(
|
||||
Self::fields(ctx, llvm_usize),
|
||||
llvm_ndarray_ty,
|
||||
"NDArray",
|
||||
&[],
|
||||
)
|
||||
}
|
||||
|
||||
fn alloca_type(&self) -> impl BasicType<'ctx> {
|
||||
self.as_abi_type().get_element_type().into_struct_type()
|
||||
}
|
||||
|
||||
fn as_base_type(&self) -> Self::Base {
|
||||
self.ty
|
||||
}
|
||||
|
||||
fn as_abi_type(&self) -> Self::ABI {
|
||||
self.as_base_type()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> StructProxyType<'ctx> for NDArrayType<'ctx> {
|
||||
type StructFields = NDArrayStructFields<'ctx>;
|
||||
|
||||
fn get_fields(&self) -> Self::StructFields {
|
||||
Self::fields(self.ty.get_context(), self.llvm_usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<NDArrayType<'ctx>> for PointerType<'ctx> {
|
||||
fn from(value: NDArrayType<'ctx>) -> Self {
|
||||
value.as_base_type()
|
||||
}
|
||||
}
|
267
nac3core/src/codegen/types/ndarray/nditer.rs
Normal file
267
nac3core/src/codegen/types/ndarray/nditer.rs
Normal file
@ -0,0 +1,267 @@
|
||||
use inkwell::{
|
||||
context::{AsContextRef, Context},
|
||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType},
|
||||
values::{IntValue, PointerValue, StructValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
use super::ProxyType;
|
||||
use crate::codegen::{
|
||||
irrt,
|
||||
types::structure::{
|
||||
check_struct_type_matches_fields, StructField, StructFields, StructProxyType,
|
||||
},
|
||||
values::{
|
||||
ndarray::{NDArrayValue, NDIterValue},
|
||||
ArrayLikeValue, ArraySliceValue, ProxyValue, TypedArrayLikeAdapter,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct NDIterType<'ctx> {
|
||||
ty: PointerType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct NDIterStructFields<'ctx> {
|
||||
#[value_type(usize)]
|
||||
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||
pub strides: StructField<'ctx, PointerValue<'ctx>>,
|
||||
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||
pub indices: StructField<'ctx, PointerValue<'ctx>>,
|
||||
#[value_type(usize)]
|
||||
pub nth: StructField<'ctx, IntValue<'ctx>>,
|
||||
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||
pub element: StructField<'ctx, PointerValue<'ctx>>,
|
||||
#[value_type(usize)]
|
||||
pub size: StructField<'ctx, IntValue<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> NDIterType<'ctx> {
|
||||
/// Returns an instance of [`StructFields`] containing all field accessors for this type.
|
||||
#[must_use]
|
||||
fn fields(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> NDIterStructFields<'ctx> {
|
||||
NDIterStructFields::new(ctx, llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an LLVM type corresponding to the expected structure of an `NDIter`.
|
||||
#[must_use]
|
||||
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
|
||||
let field_tys =
|
||||
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||
|
||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||
}
|
||||
|
||||
fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self {
|
||||
let llvm_nditer = Self::llvm_type(ctx, llvm_usize);
|
||||
|
||||
Self { ty: llvm_nditer, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an instance of [`NDIter`].
|
||||
#[must_use]
|
||||
pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self {
|
||||
Self::new_impl(ctx.ctx, ctx.get_size_type())
|
||||
}
|
||||
|
||||
/// Creates an instance of [`NDIter`].
|
||||
#[must_use]
|
||||
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> Self {
|
||||
Self::new_impl(ctx, generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
/// Creates an [`NDIterType`] from a [`StructType`] representing an `NDIter`.
|
||||
#[must_use]
|
||||
pub fn from_struct_type(ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`.
|
||||
#[must_use]
|
||||
pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok());
|
||||
|
||||
Self { ty: ptr_ty, llvm_usize }
|
||||
}
|
||||
|
||||
/// Returns the type of the `size` field of this `nditer` type.
|
||||
#[must_use]
|
||||
pub fn size_type(&self) -> IntType<'ctx> {
|
||||
self.llvm_usize
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`NDIterValue`] as if by calling `alloca` on the base type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca`].
|
||||
#[must_use]
|
||||
pub fn alloca(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
parent: NDArrayValue<'ctx>,
|
||||
indices: ArraySliceValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca(ctx, name),
|
||||
parent,
|
||||
indices,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`NDIterValue`] as if by calling `alloca` on the base type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca_var`].
|
||||
#[must_use]
|
||||
pub fn alloca_var<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
parent: NDArrayValue<'ctx>,
|
||||
indices: ArraySliceValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca_var(generator, ctx, name),
|
||||
parent,
|
||||
indices,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Allocate an [`NDIter`] that iterates through the given `ndarray`.
|
||||
#[must_use]
|
||||
pub fn construct<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let nditer = self.raw_alloca_var(generator, ctx, None);
|
||||
let ndims = self.llvm_usize.const_int(ndarray.get_type().ndims(), false);
|
||||
|
||||
// The caller has the responsibility to allocate 'indices' for `NDIter`.
|
||||
let indices =
|
||||
generator.gen_array_var_alloc(ctx, self.llvm_usize.into(), ndims, None).unwrap();
|
||||
let indices =
|
||||
TypedArrayLikeAdapter::from(indices, |_, _, v| v.into_int_value(), |_, _, v| v.into());
|
||||
|
||||
let nditer =
|
||||
self.map_pointer_value(nditer, ndarray, indices.as_slice_value(ctx, generator), None);
|
||||
|
||||
irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, &indices);
|
||||
|
||||
nditer
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn map_struct_value<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
value: StructValue<'ctx>,
|
||||
parent: NDArrayValue<'ctx>,
|
||||
indices: ArraySliceValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_struct_value(
|
||||
generator,
|
||||
ctx,
|
||||
value,
|
||||
parent,
|
||||
indices,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn map_pointer_value(
|
||||
&self,
|
||||
value: PointerValue<'ctx>,
|
||||
parent: NDArrayValue<'ctx>,
|
||||
indices: ArraySliceValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
value,
|
||||
parent,
|
||||
indices,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> {
|
||||
type ABI = PointerType<'ctx>;
|
||||
type Base = PointerType<'ctx>;
|
||||
type Value = NDIterValue<'ctx>;
|
||||
|
||||
fn is_representable(
|
||||
llvm_ty: impl BasicType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||
Self::has_same_repr(ty, llvm_usize)
|
||||
} else {
|
||||
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> {
|
||||
let ctx = ty.get_context();
|
||||
|
||||
let llvm_ty = ty.get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ty else {
|
||||
return Err(format!("Expected struct type for `NDIter` type, got {llvm_ty}"));
|
||||
};
|
||||
|
||||
check_struct_type_matches_fields(
|
||||
Self::fields(ctx, llvm_usize),
|
||||
llvm_ndarray_ty,
|
||||
"NDIter",
|
||||
&[],
|
||||
)
|
||||
}
|
||||
|
||||
fn alloca_type(&self) -> impl BasicType<'ctx> {
|
||||
self.as_abi_type().get_element_type().into_struct_type()
|
||||
}
|
||||
|
||||
fn as_base_type(&self) -> Self::Base {
|
||||
self.ty
|
||||
}
|
||||
|
||||
fn as_abi_type(&self) -> Self::ABI {
|
||||
self.as_base_type()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> StructProxyType<'ctx> for NDIterType<'ctx> {
|
||||
type StructFields = NDIterStructFields<'ctx>;
|
||||
|
||||
fn get_fields(&self) -> Self::StructFields {
|
||||
Self::fields(self.ty.get_context(), self.llvm_usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<NDIterType<'ctx>> for PointerType<'ctx> {
|
||||
fn from(value: NDIterType<'ctx>) -> Self {
|
||||
value.as_base_type()
|
||||
}
|
||||
}
|
208
nac3core/src/codegen/types/range.rs
Normal file
208
nac3core/src/codegen/types/range.rs
Normal file
@ -0,0 +1,208 @@
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
types::{AnyTypeEnum, ArrayType, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||
values::{ArrayValue, PointerValue},
|
||||
AddressSpace,
|
||||
};
|
||||
|
||||
use super::ProxyType;
|
||||
use crate::{
|
||||
codegen::{
|
||||
values::RangeValue,
|
||||
{CodeGenContext, CodeGenerator},
|
||||
},
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
|
||||
/// Proxy type for a `range` type in LLVM.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct RangeType<'ctx> {
|
||||
ty: PointerType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx> RangeType<'ctx> {
|
||||
/// Creates an LLVM type corresponding to the expected structure of a `Range`.
|
||||
#[must_use]
|
||||
fn llvm_type(ctx: &'ctx Context) -> PointerType<'ctx> {
|
||||
// typedef int32_t Range[3];
|
||||
let llvm_i32 = ctx.i32_type();
|
||||
llvm_i32.array_type(3).ptr_type(AddressSpace::default())
|
||||
}
|
||||
|
||||
fn new_impl(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self {
|
||||
let llvm_range = Self::llvm_type(ctx);
|
||||
|
||||
RangeType { ty: llvm_range, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an instance of [`RangeType`].
|
||||
#[must_use]
|
||||
pub fn new(ctx: &CodeGenContext<'ctx, '_>) -> Self {
|
||||
Self::new_impl(ctx.ctx, ctx.get_size_type())
|
||||
}
|
||||
|
||||
/// Creates an instance of [`RangeType`].
|
||||
#[must_use]
|
||||
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> Self {
|
||||
Self::new_impl(ctx, generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
/// Creates an [`RangeType`] from a [unifier type][Type].
|
||||
#[must_use]
|
||||
pub fn from_unifier_type(ctx: &mut CodeGenContext<'ctx, '_>, ty: Type) -> Self {
|
||||
// Check unifier type
|
||||
assert!(
|
||||
matches!(&*ctx.unifier.get_ty_immutable(ty), TypeEnum::TObj { obj_id, .. } if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap())
|
||||
);
|
||||
|
||||
Self::new(ctx)
|
||||
}
|
||||
|
||||
/// Creates an [`RangeType`] from a [`ArrayType`].
|
||||
#[must_use]
|
||||
pub fn from_array_type(arr_ty: ArrayType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
Self::from_pointer_type(arr_ty.ptr_type(AddressSpace::default()), llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an [`RangeType`] from a [`PointerType`].
|
||||
#[must_use]
|
||||
pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
debug_assert!(Self::has_same_repr(ptr_ty, llvm_usize).is_ok());
|
||||
|
||||
RangeType { ty: ptr_ty, llvm_usize }
|
||||
}
|
||||
|
||||
/// Returns the type of all fields of this `range` type.
|
||||
#[must_use]
|
||||
pub fn value_type(&self) -> IntType<'ctx> {
|
||||
self.as_abi_type().get_element_type().into_array_type().get_element_type().into_int_type()
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca`].
|
||||
#[must_use]
|
||||
pub fn alloca<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca(ctx, name),
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca_var`].
|
||||
#[must_use]
|
||||
pub fn alloca_var<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca_var(generator, ctx, name),
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`RangeValue`].
|
||||
#[must_use]
|
||||
pub fn map_array_value<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
value: ArrayValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_array_value(
|
||||
generator,
|
||||
ctx,
|
||||
value,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`RangeValue`].
|
||||
#[must_use]
|
||||
pub fn map_pointer_value(
|
||||
&self,
|
||||
value: PointerValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(value, self.llvm_usize, name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
|
||||
type ABI = PointerType<'ctx>;
|
||||
type Base = PointerType<'ctx>;
|
||||
type Value = RangeValue<'ctx>;
|
||||
|
||||
fn is_representable(
|
||||
llvm_ty: impl BasicType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||
Self::has_same_repr(ty, llvm_usize)
|
||||
} else {
|
||||
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn has_same_repr(ty: Self::Base, _: IntType<'ctx>) -> Result<(), String> {
|
||||
let llvm_range_ty = ty.get_element_type();
|
||||
let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else {
|
||||
return Err(format!("Expected array type for `range` type, got {llvm_range_ty}"));
|
||||
};
|
||||
if llvm_range_ty.len() != 3 {
|
||||
return Err(format!(
|
||||
"Expected 3 elements for `range` type, got {}",
|
||||
llvm_range_ty.len()
|
||||
));
|
||||
}
|
||||
|
||||
let llvm_range_elem_ty = llvm_range_ty.get_element_type();
|
||||
let Ok(llvm_range_elem_ty) = IntType::try_from(llvm_range_elem_ty) else {
|
||||
return Err(format!(
|
||||
"Expected int type for `range` element type, got {llvm_range_elem_ty}"
|
||||
));
|
||||
};
|
||||
if llvm_range_elem_ty.get_bit_width() != 32 {
|
||||
return Err(format!(
|
||||
"Expected 32-bit int type for `range` element type, got {}",
|
||||
llvm_range_elem_ty.get_bit_width()
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn alloca_type(&self) -> impl BasicType<'ctx> {
|
||||
self.as_abi_type().get_element_type().into_struct_type()
|
||||
}
|
||||
|
||||
fn as_base_type(&self) -> Self::Base {
|
||||
self.ty
|
||||
}
|
||||
|
||||
fn as_abi_type(&self) -> Self::ABI {
|
||||
self.as_base_type()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<RangeType<'ctx>> for PointerType<'ctx> {
|
||||
fn from(value: RangeType<'ctx>) -> Self {
|
||||
value.as_base_type()
|
||||
}
|
||||
}
|
333
nac3core/src/codegen/types/structure.rs
Normal file
333
nac3core/src/codegen/types/structure.rs
Normal file
@ -0,0 +1,333 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use inkwell::{
|
||||
context::AsContextRef,
|
||||
types::{BasicTypeEnum, IntType, PointerType, StructType},
|
||||
values::{AggregateValueEnum, BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use super::ProxyType;
|
||||
use crate::codegen::CodeGenContext;
|
||||
|
||||
/// A LLVM type that is used to represent a corresponding structure-like type in NAC3.
|
||||
pub trait StructProxyType<'ctx>: ProxyType<'ctx, Base = PointerType<'ctx>> {
|
||||
/// The concrete type of [`StructFields`].
|
||||
type StructFields: StructFields<'ctx>;
|
||||
|
||||
/// Whether this [`StructProxyType`] has the same LLVM type representation as
|
||||
/// [`llvm_ty`][StructType].
|
||||
fn has_same_struct_repr(
|
||||
llvm_ty: StructType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
Self::has_same_pointer_repr(llvm_ty.ptr_type(AddressSpace::default()), llvm_usize)
|
||||
}
|
||||
|
||||
/// Whether this [`StructProxyType`] has the same LLVM type representation as
|
||||
/// [`llvm_ty`][PointerType].
|
||||
fn has_same_pointer_repr(
|
||||
llvm_ty: PointerType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
Self::has_same_repr(llvm_ty, llvm_usize)
|
||||
}
|
||||
|
||||
/// Returns the fields present in this [`StructProxyType`].
|
||||
#[must_use]
|
||||
fn get_fields(&self) -> Self::StructFields;
|
||||
|
||||
/// Returns the [`StructType`].
|
||||
#[must_use]
|
||||
fn get_struct_type(&self) -> StructType<'ctx> {
|
||||
self.as_base_type().get_element_type().into_struct_type()
|
||||
}
|
||||
|
||||
/// Returns the [`PointerType`] representing this type.
|
||||
#[must_use]
|
||||
fn get_pointer_type(&self) -> PointerType<'ctx> {
|
||||
self.as_base_type()
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait indicating that the structure is a field-wise representation of an LLVM structure.
|
||||
///
|
||||
/// # Usage
|
||||
///
|
||||
/// For example, for a simple C-slice LLVM structure:
|
||||
///
|
||||
/// ```ignore
|
||||
/// struct CSliceFields<'ctx> {
|
||||
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||
/// len: StructField<'ctx, IntValue<'ctx>>
|
||||
/// }
|
||||
/// ```
|
||||
pub trait StructFields<'ctx>: Eq + Copy {
|
||||
/// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types.
|
||||
fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self;
|
||||
|
||||
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
|
||||
/// the type definition.
|
||||
#[must_use]
|
||||
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>;
|
||||
|
||||
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
|
||||
/// in the type definition.
|
||||
#[must_use]
|
||||
fn iter(&self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)> {
|
||||
self.to_vec().into_iter()
|
||||
}
|
||||
|
||||
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
|
||||
/// the type definition.
|
||||
#[must_use]
|
||||
fn into_vec(self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.to_vec()
|
||||
}
|
||||
|
||||
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
|
||||
/// in the type definition.
|
||||
#[must_use]
|
||||
fn into_iter(self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.into_vec().into_iter()
|
||||
}
|
||||
|
||||
/// Returns the field index of a field in this structure.
|
||||
fn index_of_field<V>(&self, name: impl FnOnce(&Self) -> StructField<'ctx, V>) -> u32
|
||||
where
|
||||
V: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
||||
{
|
||||
let field_name = name(self).name;
|
||||
self.index_of_field_name(field_name).unwrap()
|
||||
}
|
||||
|
||||
/// Returns the field index of a field with the given name in this structure.
|
||||
fn index_of_field_name(&self, field_name: &str) -> Option<u32> {
|
||||
self.iter().find_position(|(name, _)| *name == field_name).map(|(idx, _)| idx as u32)
|
||||
}
|
||||
}
|
||||
|
||||
/// A single field of an LLVM structure.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct StructField<'ctx, Value>
|
||||
where
|
||||
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
||||
{
|
||||
/// The index of this field within the structure.
|
||||
index: u32,
|
||||
|
||||
/// The name of this field.
|
||||
name: &'static str,
|
||||
|
||||
/// The type of this field.
|
||||
ty: BasicTypeEnum<'ctx>,
|
||||
|
||||
/// Instance of [`PhantomData`] containing [`Value`], used to implement automatic downcasts.
|
||||
_value_ty: PhantomData<Value>,
|
||||
}
|
||||
|
||||
impl<'ctx, Value> StructField<'ctx, Value>
|
||||
where
|
||||
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
||||
{
|
||||
/// Creates an instance of [`StructField`].
|
||||
///
|
||||
/// * `idx_counter` - The instance of [`FieldIndexCounter`] used to track the current field
|
||||
/// index.
|
||||
/// * `name` - Name of the field.
|
||||
/// * `ty` - The type of this field.
|
||||
pub fn create(
|
||||
idx_counter: &mut FieldIndexCounter,
|
||||
name: &'static str,
|
||||
ty: impl Into<BasicTypeEnum<'ctx>>,
|
||||
) -> Self {
|
||||
StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData }
|
||||
}
|
||||
|
||||
/// Creates an instance of [`StructField`] with a given index.
|
||||
///
|
||||
/// * `index` - The index of this field within its enclosing structure.
|
||||
/// * `name` - Name of the field.
|
||||
/// * `ty` - The type of this field.
|
||||
pub fn create_at(index: u32, name: &'static str, ty: impl Into<BasicTypeEnum<'ctx>>) -> Self {
|
||||
StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
|
||||
}
|
||||
|
||||
/// Returns the name of this field.
|
||||
#[must_use]
|
||||
pub fn name(&self) -> &'static str {
|
||||
self.name
|
||||
}
|
||||
|
||||
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
|
||||
/// {idx...}, i32 {self.index}`.
|
||||
pub fn ptr_by_array_gep(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pobj: PointerValue<'ctx>,
|
||||
idx: &[IntValue<'ctx>],
|
||||
) -> PointerValue<'ctx> {
|
||||
unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
pobj,
|
||||
&[idx, &[ctx.ctx.i32_type().const_int(u64::from(self.index), false)]].concat(),
|
||||
"",
|
||||
)
|
||||
}
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Creates a pointer to this field in an arbitrary structure by performing the equivalent of
|
||||
/// `getelementptr i32 0, i32 {self.index}`.
|
||||
pub fn ptr_by_gep(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pobj: PointerValue<'ctx>,
|
||||
obj_name: Option<&'ctx str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
ctx.builder
|
||||
.build_struct_gep(
|
||||
pobj,
|
||||
self.index,
|
||||
&obj_name.map(|name| format!("{name}.{}.addr", self.name)).unwrap_or_default(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Gets the value of this field for a given `obj`.
|
||||
#[must_use]
|
||||
pub fn extract_value(&self, ctx: &CodeGenContext<'ctx, '_>, obj: StructValue<'ctx>) -> Value {
|
||||
Value::try_from(
|
||||
ctx.builder
|
||||
.build_extract_value(
|
||||
obj,
|
||||
self.index,
|
||||
&format!("{}.{}", obj.get_name().to_str().unwrap(), self.name),
|
||||
)
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Sets the value of this field for a given `obj`.
|
||||
#[must_use]
|
||||
pub fn insert_value(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
obj: StructValue<'ctx>,
|
||||
value: Value,
|
||||
) -> StructValue<'ctx> {
|
||||
let obj_name = obj.get_name().to_str().unwrap();
|
||||
let new_obj_name = if obj_name.chars().all(char::is_numeric) { "" } else { obj_name };
|
||||
|
||||
ctx.builder
|
||||
.build_insert_value(obj, value, self.index, new_obj_name)
|
||||
.map(AggregateValueEnum::into_struct_value)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Loads the value of this field for a pointer-to-structure.
|
||||
pub fn load(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pobj: PointerValue<'ctx>,
|
||||
obj_name: Option<&'ctx str>,
|
||||
) -> Value {
|
||||
ctx.builder
|
||||
.build_load(
|
||||
self.ptr_by_gep(ctx, pobj, obj_name),
|
||||
&obj_name.map(|name| format!("{name}.{}", self.name)).unwrap_or_default(),
|
||||
)
|
||||
.map_err(|_| ())
|
||||
.and_then(|value| Value::try_from(value))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Stores the value of this field for a pointer-to-structure.
|
||||
pub fn store(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pobj: PointerValue<'ctx>,
|
||||
value: Value,
|
||||
obj_name: Option<&'ctx str>,
|
||||
) {
|
||||
ctx.builder.build_store(self.ptr_by_gep(ctx, pobj, obj_name), value).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, Value> From<StructField<'ctx, Value>> for (&'static str, BasicTypeEnum<'ctx>)
|
||||
where
|
||||
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
||||
{
|
||||
fn from(value: StructField<'ctx, Value>) -> Self {
|
||||
(value.name, value.ty)
|
||||
}
|
||||
}
|
||||
|
||||
/// A counter that tracks the next index of a field using a monotonically increasing counter.
|
||||
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct FieldIndexCounter(u32);
|
||||
|
||||
impl FieldIndexCounter {
|
||||
/// Increments the number stored by this counter, returning the previous value.
|
||||
///
|
||||
/// Functionally equivalent to `i++` in C-based languages.
|
||||
pub fn increment(&mut self) -> u32 {
|
||||
let v = self.0;
|
||||
self.0 += 1;
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
type FieldTypeVerifier<'ctx> = dyn Fn(BasicTypeEnum<'ctx>) -> Result<(), String>;
|
||||
|
||||
/// Checks whether [`llvm_ty`][StructType] contains the fields described by the given
|
||||
/// [`StructFields`] instance.
|
||||
///
|
||||
/// By default, this function will compare the type of each field in `expected_fields` against
|
||||
/// `llvm_ty`. To override this behavior for individual fields, pass in overrides to
|
||||
/// `custom_verifiers`, which will use the specified verifier when a field with the matching field
|
||||
/// name is being checked.
|
||||
pub(super) fn check_struct_type_matches_fields<'ctx>(
|
||||
expected_fields: impl StructFields<'ctx>,
|
||||
llvm_ty: StructType<'ctx>,
|
||||
ty_name: &'static str,
|
||||
custom_verifiers: &[(&str, &FieldTypeVerifier<'ctx>)],
|
||||
) -> Result<(), String> {
|
||||
let expected_fields = expected_fields.to_vec();
|
||||
|
||||
if llvm_ty.count_fields() != u32::try_from(expected_fields.len()).unwrap() {
|
||||
return Err(format!(
|
||||
"Expected {} fields in `{ty_name}`, got {}",
|
||||
expected_fields.len(),
|
||||
llvm_ty.count_fields(),
|
||||
));
|
||||
}
|
||||
|
||||
expected_fields
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, (field_name, expected_ty))| {
|
||||
(field_name, expected_ty, llvm_ty.get_field_type_at_index(i as u32).unwrap())
|
||||
})
|
||||
.try_for_each(|(field_name, expected_ty, actual_ty)| {
|
||||
if let Some((_, verifier)) =
|
||||
custom_verifiers.iter().find(|verifier| verifier.0 == field_name)
|
||||
{
|
||||
verifier(actual_ty)
|
||||
} else if expected_ty == actual_ty {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Expected {expected_ty} for `{ty_name}.{field_name}`, got {actual_ty}"))
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
213
nac3core/src/codegen/types/tuple.rs
Normal file
213
nac3core/src/codegen/types/tuple.rs
Normal file
@ -0,0 +1,213 @@
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
types::{BasicType, BasicTypeEnum, IntType, PointerType, StructType},
|
||||
values::{BasicValueEnum, PointerValue, StructValue},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use super::ProxyType;
|
||||
use crate::{
|
||||
codegen::{values::TupleValue, CodeGenContext, CodeGenerator},
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct TupleType<'ctx> {
|
||||
ty: StructType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx> TupleType<'ctx> {
|
||||
/// Creates an LLVM type corresponding to the expected structure of a tuple.
|
||||
#[must_use]
|
||||
fn llvm_type(ctx: &'ctx Context, tys: &[BasicTypeEnum<'ctx>]) -> StructType<'ctx> {
|
||||
ctx.struct_type(tys, false)
|
||||
}
|
||||
|
||||
fn new_impl(
|
||||
ctx: &'ctx Context,
|
||||
tys: &[BasicTypeEnum<'ctx>],
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Self {
|
||||
let llvm_tuple = Self::llvm_type(ctx, tys);
|
||||
|
||||
Self { ty: llvm_tuple, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an instance of [`TupleType`].
|
||||
#[must_use]
|
||||
pub fn new(ctx: &CodeGenContext<'ctx, '_>, tys: &[impl BasicType<'ctx>]) -> Self {
|
||||
Self::new_impl(
|
||||
ctx.ctx,
|
||||
&tys.iter().map(BasicType::as_basic_type_enum).collect_vec(),
|
||||
ctx.get_size_type(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates an instance of [`TupleType`].
|
||||
#[must_use]
|
||||
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
tys: &[BasicTypeEnum<'ctx>],
|
||||
) -> Self {
|
||||
Self::new_impl(ctx, tys, generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
/// Creates an [`TupleType`] from a [unifier type][Type].
|
||||
#[must_use]
|
||||
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ty: Type,
|
||||
) -> Self {
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
// Sanity check on object type.
|
||||
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty_immutable(ty) else {
|
||||
panic!("Expected type to be a TypeEnum::TTuple, got {}", ctx.unifier.stringify(ty));
|
||||
};
|
||||
|
||||
let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec();
|
||||
Self { ty: Self::llvm_type(ctx.ctx, &llvm_tys), llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an [`TupleType`] from a [`StructType`].
|
||||
#[must_use]
|
||||
pub fn from_struct_type(struct_ty: StructType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
debug_assert!(Self::has_same_repr(struct_ty, llvm_usize).is_ok());
|
||||
|
||||
TupleType { ty: struct_ty, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an [`TupleType`] from a [`PointerType`].
|
||||
#[must_use]
|
||||
pub fn from_pointer_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
Self::from_struct_type(ptr_ty.get_element_type().into_struct_type(), llvm_usize)
|
||||
}
|
||||
|
||||
/// Returns the number of elements present in this [`TupleType`].
|
||||
#[must_use]
|
||||
pub fn num_elements(&self) -> u32 {
|
||||
self.ty.count_fields()
|
||||
}
|
||||
|
||||
/// Returns the type of the tuple element at the given `index`, or [`None`] if `index` is out of
|
||||
/// range.
|
||||
#[must_use]
|
||||
pub fn type_at_index(&self, index: u32) -> Option<BasicTypeEnum<'ctx>> {
|
||||
if index < self.num_elements() {
|
||||
Some(unsafe { self.type_at_index_unchecked(index) })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the type of the tuple element at the given `index`.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The caller must ensure that the index is valid.
|
||||
#[must_use]
|
||||
pub unsafe fn type_at_index_unchecked(&self, index: u32) -> BasicTypeEnum<'ctx> {
|
||||
self.ty.get_field_type_at_index_unchecked(index)
|
||||
}
|
||||
|
||||
/// Constructs a [`TupleValue`] from this type by zero-initializing the tuple value.
|
||||
#[must_use]
|
||||
pub fn construct(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
self.map_struct_value(
|
||||
Self::llvm_type(ctx.ctx, &self.ty.get_field_types()).const_zero(),
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Constructs a [`TupleValue`] from `objects`. The resulting tuple preserves the order of
|
||||
/// objects.
|
||||
#[must_use]
|
||||
pub fn construct_from_objects<I: IntoIterator<Item = BasicValueEnum<'ctx>>>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
objects: I,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
let values = objects.into_iter().collect_vec();
|
||||
|
||||
assert_eq!(values.len(), self.num_elements() as usize);
|
||||
assert!(values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.all(|(i, v)| { v.get_type() == unsafe { self.type_at_index_unchecked(i as u32) } }));
|
||||
|
||||
let mut value = self.construct(ctx, name);
|
||||
for (i, val) in values.into_iter().enumerate() {
|
||||
value.store_element(ctx, i as u32, val);
|
||||
}
|
||||
|
||||
value
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`ListValue`].
|
||||
#[must_use]
|
||||
pub fn map_struct_value(
|
||||
&self,
|
||||
value: StructValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_struct_value(value, self.llvm_usize, name)
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`TupleValue`].
|
||||
#[must_use]
|
||||
pub fn map_pointer_value(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
value: PointerValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(ctx, value, self.llvm_usize, name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyType<'ctx> for TupleType<'ctx> {
|
||||
type ABI = StructType<'ctx>;
|
||||
type Base = StructType<'ctx>;
|
||||
type Value = TupleValue<'ctx>;
|
||||
|
||||
fn is_representable(
|
||||
llvm_ty: impl BasicType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
if let BasicTypeEnum::StructType(ty) = llvm_ty.as_basic_type_enum() {
|
||||
Self::has_same_repr(ty, llvm_usize)
|
||||
} else {
|
||||
Err(format!("Expected struct type, got {llvm_ty:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn has_same_repr(_: Self::Base, _: IntType<'ctx>) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn alloca_type(&self) -> impl BasicType<'ctx> {
|
||||
self.as_base_type()
|
||||
}
|
||||
|
||||
fn as_base_type(&self) -> Self::Base {
|
||||
self.ty
|
||||
}
|
||||
|
||||
fn as_abi_type(&self) -> Self::ABI {
|
||||
self.as_base_type()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<TupleType<'ctx>> for StructType<'ctx> {
|
||||
fn from(value: TupleType<'ctx>) -> Self {
|
||||
value.as_base_type()
|
||||
}
|
||||
}
|
3
nac3core/src/codegen/types/utils/mod.rs
Normal file
3
nac3core/src/codegen/types/utils/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
pub use slice::*;
|
||||
|
||||
mod slice;
|
291
nac3core/src/codegen/types/utils/slice.rs
Normal file
291
nac3core/src/codegen/types/utils/slice.rs
Normal file
@ -0,0 +1,291 @@
|
||||
use inkwell::{
|
||||
context::{AsContextRef, Context, ContextRef},
|
||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType, StructType},
|
||||
values::{IntValue, PointerValue, StructValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use nac3core_derive::StructFields;
|
||||
|
||||
use crate::codegen::{
|
||||
types::{
|
||||
structure::{
|
||||
check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields,
|
||||
StructProxyType,
|
||||
},
|
||||
ProxyType,
|
||||
},
|
||||
values::utils::SliceValue,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct SliceType<'ctx> {
|
||||
ty: PointerType<'ctx>,
|
||||
int_ty: IntType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||
pub struct SliceStructFields<'ctx> {
|
||||
#[value_type(bool_type())]
|
||||
pub start_defined: StructField<'ctx, IntValue<'ctx>>,
|
||||
#[value_type(usize)]
|
||||
pub start: StructField<'ctx, IntValue<'ctx>>,
|
||||
#[value_type(bool_type())]
|
||||
pub stop_defined: StructField<'ctx, IntValue<'ctx>>,
|
||||
#[value_type(usize)]
|
||||
pub stop: StructField<'ctx, IntValue<'ctx>>,
|
||||
#[value_type(bool_type())]
|
||||
pub step_defined: StructField<'ctx, IntValue<'ctx>>,
|
||||
#[value_type(usize)]
|
||||
pub step: StructField<'ctx, IntValue<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> SliceStructFields<'ctx> {
|
||||
/// Creates a new instance of [`SliceStructFields`] with a custom integer type for its range values.
|
||||
#[must_use]
|
||||
pub fn new_sized(ctx: &impl AsContextRef<'ctx>, int_ty: IntType<'ctx>) -> Self {
|
||||
let ctx = unsafe { ContextRef::new(ctx.as_ctx_ref()) };
|
||||
let mut counter = FieldIndexCounter::default();
|
||||
|
||||
SliceStructFields {
|
||||
start_defined: StructField::create(&mut counter, "start_defined", ctx.bool_type()),
|
||||
start: StructField::create(&mut counter, "start", int_ty),
|
||||
stop_defined: StructField::create(&mut counter, "stop_defined", ctx.bool_type()),
|
||||
stop: StructField::create(&mut counter, "stop", int_ty),
|
||||
step_defined: StructField::create(&mut counter, "step_defined", ctx.bool_type()),
|
||||
step: StructField::create(&mut counter, "step", int_ty),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> SliceType<'ctx> {
|
||||
/// Creates an LLVM type corresponding to the expected structure of a `Slice`.
|
||||
#[must_use]
|
||||
fn llvm_type(ctx: &'ctx Context, int_ty: IntType<'ctx>) -> PointerType<'ctx> {
|
||||
let field_tys = SliceStructFields::new_sized(&int_ty.get_context(), int_ty)
|
||||
.into_iter()
|
||||
.map(|field| field.1)
|
||||
.collect_vec();
|
||||
|
||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||
}
|
||||
|
||||
fn new_impl(ctx: &'ctx Context, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||
let llvm_ty = Self::llvm_type(ctx, int_ty);
|
||||
|
||||
Self { ty: llvm_ty, int_ty, llvm_usize }
|
||||
}
|
||||
|
||||
/// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type.
|
||||
#[must_use]
|
||||
pub fn new(ctx: &CodeGenContext<'ctx, '_>, int_ty: IntType<'ctx>) -> Self {
|
||||
Self::new_impl(ctx.ctx, int_ty, ctx.get_size_type())
|
||||
}
|
||||
|
||||
/// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type.
|
||||
#[must_use]
|
||||
pub fn new_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
int_ty: IntType<'ctx>,
|
||||
) -> Self {
|
||||
Self::new_impl(ctx, int_ty, generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
/// Creates an instance of [`SliceType`] with `usize` as its backing integer type.
|
||||
#[must_use]
|
||||
pub fn new_usize(ctx: &CodeGenContext<'ctx, '_>) -> Self {
|
||||
Self::new_impl(ctx.ctx, ctx.get_size_type(), ctx.get_size_type())
|
||||
}
|
||||
|
||||
/// Creates an instance of [`SliceType`] with `usize` as its backing integer type.
|
||||
#[must_use]
|
||||
pub fn new_usize_with_generator<G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &'ctx Context,
|
||||
) -> Self {
|
||||
Self::new_impl(ctx, generator.get_size_type(ctx), generator.get_size_type(ctx))
|
||||
}
|
||||
|
||||
/// Creates an [`SliceType`] from a [`StructType`] representing a `slice`.
|
||||
#[must_use]
|
||||
pub fn from_struct_type(
|
||||
ty: StructType<'ctx>,
|
||||
int_ty: IntType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Self {
|
||||
Self::from_pointer_type(ty.ptr_type(AddressSpace::default()), int_ty, llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an [`SliceType`] from a [`PointerType`] representing a `slice`.
|
||||
#[must_use]
|
||||
pub fn from_pointer_type(
|
||||
ptr_ty: PointerType<'ctx>,
|
||||
int_ty: IntType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Self {
|
||||
debug_assert!(Self::has_same_repr(ptr_ty, int_ty).is_ok());
|
||||
|
||||
Self { ty: ptr_ty, int_ty, llvm_usize }
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn element_type(&self) -> IntType<'ctx> {
|
||||
self.int_ty
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`SliceValue`] as if by calling `alloca` on the base type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca`].
|
||||
#[must_use]
|
||||
pub fn alloca(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca(ctx, name),
|
||||
self.int_ty,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Allocates an instance of [`SliceValue`] as if by calling `alloca` on the base type.
|
||||
///
|
||||
/// See [`ProxyType::raw_alloca_var`].
|
||||
#[must_use]
|
||||
pub fn alloca_var<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
self.raw_alloca_var(generator, ctx, name),
|
||||
self.int_ty,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`SliceValue`].
|
||||
#[must_use]
|
||||
pub fn map_struct_value<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
value: StructValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_struct_value(
|
||||
generator,
|
||||
ctx,
|
||||
value,
|
||||
self.int_ty,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
|
||||
/// Converts an existing value into a [`ContiguousNDArrayValue`].
|
||||
#[must_use]
|
||||
pub fn map_pointer_value(
|
||||
&self,
|
||||
value: PointerValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> <Self as ProxyType<'ctx>>::Value {
|
||||
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||
value,
|
||||
self.int_ty,
|
||||
self.llvm_usize,
|
||||
name,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> {
|
||||
type ABI = PointerType<'ctx>;
|
||||
type Base = PointerType<'ctx>;
|
||||
type Value = SliceValue<'ctx>;
|
||||
|
||||
fn is_representable(
|
||||
llvm_ty: impl BasicType<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||
Self::has_same_repr(ty, llvm_usize)
|
||||
} else {
|
||||
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||
}
|
||||
}
|
||||
|
||||
fn has_same_repr(ty: Self::Base, llvm_usize: IntType<'ctx>) -> Result<(), String> {
|
||||
let ctx = ty.get_context();
|
||||
|
||||
let fields = SliceStructFields::new(ctx, llvm_usize);
|
||||
|
||||
let llvm_ty = ty.get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else {
|
||||
return Err(format!("Expected struct type for `Slice` type, got {llvm_ty}"));
|
||||
};
|
||||
|
||||
check_struct_type_matches_fields(
|
||||
fields,
|
||||
llvm_ty,
|
||||
"Slice",
|
||||
&[
|
||||
(fields.start.name(), &|ty| {
|
||||
if ty.is_int_type() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Expected int type for `Slice.start`, got {ty}"))
|
||||
}
|
||||
}),
|
||||
(fields.stop.name(), &|ty| {
|
||||
if ty.is_int_type() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Expected int type for `Slice.stop`, got {ty}"))
|
||||
}
|
||||
}),
|
||||
(fields.step.name(), &|ty| {
|
||||
if ty.is_int_type() {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(format!("Expected int type for `Slice.step`, got {ty}"))
|
||||
}
|
||||
}),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
fn alloca_type(&self) -> impl BasicType<'ctx> {
|
||||
self.as_abi_type().get_element_type().into_struct_type()
|
||||
}
|
||||
|
||||
fn as_base_type(&self) -> Self::Base {
|
||||
self.ty
|
||||
}
|
||||
|
||||
fn as_abi_type(&self) -> Self::ABI {
|
||||
self.as_base_type()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> StructProxyType<'ctx> for SliceType<'ctx> {
|
||||
type StructFields = SliceStructFields<'ctx>;
|
||||
|
||||
fn get_fields(&self) -> Self::StructFields {
|
||||
SliceStructFields::new_sized(&self.ty.get_context(), self.int_ty)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<SliceType<'ctx>> for PointerType<'ctx> {
|
||||
fn from(value: SliceType<'ctx>) -> Self {
|
||||
value.as_base_type()
|
||||
}
|
||||
}
|
439
nac3core/src/codegen/values/array.rs
Normal file
439
nac3core/src/codegen/values/array.rs
Normal file
@ -0,0 +1,439 @@
|
||||
use inkwell::{
|
||||
types::AnyTypeEnum,
|
||||
values::{BasicValueEnum, IntValue, PointerValue},
|
||||
IntPredicate,
|
||||
};
|
||||
|
||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||
|
||||
/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of
|
||||
/// elements.
|
||||
pub trait ArrayLikeValue<'ctx> {
|
||||
/// Returns the element type of this array-like value.
|
||||
fn element_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
) -> AnyTypeEnum<'ctx>;
|
||||
|
||||
/// Returns the base pointer to the array.
|
||||
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
) -> PointerValue<'ctx>;
|
||||
|
||||
/// Returns the size of this array-like value.
|
||||
fn size<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
) -> IntValue<'ctx>;
|
||||
|
||||
/// Returns a [`ArraySliceValue`] representing this value.
|
||||
fn as_slice_value<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
) -> ArraySliceValue<'ctx> {
|
||||
ArraySliceValue::from_ptr_val(
|
||||
self.base_ptr(ctx, generator),
|
||||
self.size(ctx, generator),
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// An array-like value that can be indexed by memory offset.
|
||||
pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> {
|
||||
/// # Safety
|
||||
///
|
||||
/// This function should be called with a valid index.
|
||||
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx>;
|
||||
|
||||
/// Returns the pointer to the data at the `idx`-th index.
|
||||
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx>;
|
||||
}
|
||||
|
||||
/// An array-like value that can have its array elements accessed as a [`BasicValueEnum`].
|
||||
pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>:
|
||||
ArrayLikeIndexer<'ctx, Index>
|
||||
{
|
||||
/// # Safety
|
||||
///
|
||||
/// This function should be called with a valid index.
|
||||
unsafe fn get_unchecked<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) };
|
||||
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
|
||||
}
|
||||
|
||||
/// Returns the data at the `idx`-th index.
|
||||
fn get<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let ptr = self.ptr_offset(ctx, generator, idx, name);
|
||||
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// An array-like value that can have its array elements mutated as a [`BasicValueEnum`].
|
||||
pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>:
|
||||
ArrayLikeIndexer<'ctx, Index>
|
||||
{
|
||||
/// # Safety
|
||||
///
|
||||
/// This function should be called with a valid index.
|
||||
unsafe fn set_unchecked<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
idx: &Index,
|
||||
value: BasicValueEnum<'ctx>,
|
||||
) {
|
||||
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, None) };
|
||||
ctx.builder.build_store(ptr, value).unwrap();
|
||||
}
|
||||
|
||||
/// Sets the data at the `idx`-th index.
|
||||
fn set<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &Index,
|
||||
value: BasicValueEnum<'ctx>,
|
||||
) {
|
||||
let ptr = self.ptr_offset(ctx, generator, idx, None);
|
||||
ctx.builder.build_store(ptr, value).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
/// An array-like value that can have its array elements accessed as an arbitrary type `T`.
|
||||
pub trait TypedArrayLikeAccessor<'ctx, G: CodeGenerator + ?Sized, T, Index = IntValue<'ctx>>:
|
||||
UntypedArrayLikeAccessor<'ctx, Index>
|
||||
{
|
||||
/// Casts an element from [`BasicValueEnum`] into `T`.
|
||||
fn downcast_to_type(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
value: BasicValueEnum<'ctx>,
|
||||
) -> T;
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// This function should be called with a valid index.
|
||||
unsafe fn get_typed_unchecked(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> T {
|
||||
let value = unsafe { self.get_unchecked(ctx, generator, idx, name) };
|
||||
self.downcast_to_type(ctx, generator, value)
|
||||
}
|
||||
|
||||
/// Returns the data at the `idx`-th index.
|
||||
fn get_typed(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> T {
|
||||
let value = self.get(ctx, generator, idx, name);
|
||||
self.downcast_to_type(ctx, generator, value)
|
||||
}
|
||||
}
|
||||
|
||||
/// An array-like value that can have its array elements mutated as an arbitrary type `T`.
|
||||
pub trait TypedArrayLikeMutator<'ctx, G: CodeGenerator + ?Sized, T, Index = IntValue<'ctx>>:
|
||||
UntypedArrayLikeMutator<'ctx, Index>
|
||||
{
|
||||
/// Casts an element from T into [`BasicValueEnum`].
|
||||
fn upcast_from_type(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
value: T,
|
||||
) -> BasicValueEnum<'ctx>;
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// This function should be called with a valid index.
|
||||
unsafe fn set_typed_unchecked(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
idx: &Index,
|
||||
value: T,
|
||||
) {
|
||||
let value = self.upcast_from_type(ctx, generator, value);
|
||||
unsafe { self.set_unchecked(ctx, generator, idx, value) }
|
||||
}
|
||||
|
||||
/// Sets the data at the `idx`-th index.
|
||||
fn set_typed(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &Index,
|
||||
value: T,
|
||||
) {
|
||||
let value = self.upcast_from_type(ctx, generator, value);
|
||||
self.set(ctx, generator, idx, value);
|
||||
}
|
||||
}
|
||||
|
||||
/// An adapter for constraining untyped array values as typed values.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct TypedArrayLikeAdapter<
|
||||
'ctx,
|
||||
G: CodeGenerator + ?Sized,
|
||||
T,
|
||||
Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>,
|
||||
> {
|
||||
adapted: Adapted,
|
||||
downcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, BasicValueEnum<'ctx>) -> T,
|
||||
upcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, T) -> BasicValueEnum<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx, G: CodeGenerator + ?Sized, T, Adapted> TypedArrayLikeAdapter<'ctx, G, T, Adapted>
|
||||
where
|
||||
Adapted: ArrayLikeValue<'ctx>,
|
||||
{
|
||||
/// Creates a [`TypedArrayLikeAdapter`].
|
||||
///
|
||||
/// * `adapted` - The value to be adapted.
|
||||
/// * `downcast_fn` - The function converting a [`BasicValueEnum`] into a `T`.
|
||||
/// * `upcast_fn` - The function converting a T into a [`BasicValueEnum`].
|
||||
pub fn from(
|
||||
adapted: Adapted,
|
||||
downcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, BasicValueEnum<'ctx>) -> T,
|
||||
upcast_fn: fn(&CodeGenContext<'ctx, '_>, &G, T) -> BasicValueEnum<'ctx>,
|
||||
) -> Self {
|
||||
TypedArrayLikeAdapter { adapted, downcast_fn, upcast_fn }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, G: CodeGenerator + ?Sized, T, Adapted> ArrayLikeValue<'ctx>
|
||||
for TypedArrayLikeAdapter<'ctx, G, T, Adapted>
|
||||
where
|
||||
Adapted: ArrayLikeValue<'ctx>,
|
||||
{
|
||||
fn element_type<CG: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &CG,
|
||||
) -> AnyTypeEnum<'ctx> {
|
||||
self.adapted.element_type(ctx, generator)
|
||||
}
|
||||
|
||||
fn base_ptr<CG: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &CG,
|
||||
) -> PointerValue<'ctx> {
|
||||
self.adapted.base_ptr(ctx, generator)
|
||||
}
|
||||
|
||||
fn size<CG: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &CG,
|
||||
) -> IntValue<'ctx> {
|
||||
self.adapted.size(ctx, generator)
|
||||
}
|
||||
|
||||
fn as_slice_value<CG: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &CG,
|
||||
) -> ArraySliceValue<'ctx> {
|
||||
self.adapted.as_slice_value(ctx, generator)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index>
|
||||
for TypedArrayLikeAdapter<'ctx, G, T, Adapted>
|
||||
where
|
||||
Adapted: ArrayLikeIndexer<'ctx, Index>,
|
||||
{
|
||||
unsafe fn ptr_offset_unchecked<CG: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &CG,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
unsafe { self.adapted.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||
}
|
||||
|
||||
fn ptr_offset<CG: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut CG,
|
||||
idx: &Index,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
self.adapted.ptr_offset(ctx, generator, idx, name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index>
|
||||
for TypedArrayLikeAdapter<'ctx, G, T, Adapted>
|
||||
where
|
||||
Adapted: UntypedArrayLikeAccessor<'ctx, Index>,
|
||||
{
|
||||
}
|
||||
impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index>
|
||||
for TypedArrayLikeAdapter<'ctx, G, T, Adapted>
|
||||
where
|
||||
Adapted: UntypedArrayLikeMutator<'ctx, Index>,
|
||||
{
|
||||
}
|
||||
|
||||
impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, G, T, Index>
|
||||
for TypedArrayLikeAdapter<'ctx, G, T, Adapted>
|
||||
where
|
||||
Adapted: UntypedArrayLikeAccessor<'ctx, Index>,
|
||||
{
|
||||
fn downcast_to_type(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
value: BasicValueEnum<'ctx>,
|
||||
) -> T {
|
||||
(self.downcast_fn)(ctx, generator, value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, G: CodeGenerator + ?Sized, T, Index, Adapted> TypedArrayLikeMutator<'ctx, G, T, Index>
|
||||
for TypedArrayLikeAdapter<'ctx, G, T, Adapted>
|
||||
where
|
||||
Adapted: UntypedArrayLikeMutator<'ctx, Index>,
|
||||
{
|
||||
fn upcast_from_type(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
value: T,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
(self.upcast_fn)(ctx, generator, value)
|
||||
}
|
||||
}
|
||||
|
||||
/// An LLVM value representing an array slice, consisting of a pointer to the data and the size of
|
||||
/// the slice.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ArraySliceValue<'ctx>(PointerValue<'ctx>, IntValue<'ctx>, Option<&'ctx str>);
|
||||
|
||||
impl<'ctx> ArraySliceValue<'ctx> {
|
||||
/// Creates an [`ArraySliceValue`] from a [`PointerValue`] and its size.
|
||||
#[must_use]
|
||||
pub fn from_ptr_val(
|
||||
ptr: PointerValue<'ctx>,
|
||||
size: IntValue<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
ArraySliceValue(ptr, size, name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> From<ArraySliceValue<'ctx>> for PointerValue<'ctx> {
|
||||
fn from(value: ArraySliceValue<'ctx>) -> Self {
|
||||
value.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ArrayLikeValue<'ctx> for ArraySliceValue<'ctx> {
|
||||
fn element_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
_: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> AnyTypeEnum<'ctx> {
|
||||
self.0.get_type().get_element_type()
|
||||
}
|
||||
|
||||
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
_: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> PointerValue<'ctx> {
|
||||
self.0
|
||||
}
|
||||
|
||||
fn size<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
_: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> IntValue<'ctx> {
|
||||
self.1
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
|
||||
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = name.or(self.2).map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||
|
||||
unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
debug_assert_eq!(idx.get_type(), ctx.get_size_type());
|
||||
|
||||
let size = self.size(ctx, generator);
|
||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
in_range,
|
||||
"0:IndexError",
|
||||
"list index out of range",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ArraySliceValue<'ctx> {}
|
||||
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ArraySliceValue<'ctx> {}
|
236
nac3core/src/codegen/values/list.rs
Normal file
236
nac3core/src/codegen/values/list.rs
Normal file
@ -0,0 +1,236 @@
|
||||
use inkwell::{
|
||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
|
||||
values::{BasicValueEnum, IntValue, PointerValue, StructValue},
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
|
||||
use super::{
|
||||
structure::StructProxyValue, ArrayLikeIndexer, ArrayLikeValue, ProxyValue,
|
||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||
};
|
||||
use crate::codegen::{
|
||||
types::{
|
||||
structure::{StructField, StructProxyType},
|
||||
ListType, ProxyType,
|
||||
},
|
||||
{CodeGenContext, CodeGenerator},
|
||||
};
|
||||
|
||||
/// Proxy type for accessing a `list` value in LLVM.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ListValue<'ctx> {
|
||||
value: PointerValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
}
|
||||
|
||||
impl<'ctx> ListValue<'ctx> {
|
||||
/// Creates an [`ListValue`] from a [`PointerValue`].
|
||||
#[must_use]
|
||||
pub fn from_struct_value<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
val: StructValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
let pval = generator
|
||||
.gen_var_alloc(
|
||||
ctx,
|
||||
val.get_type().into(),
|
||||
name.map(|name| format!("{name}.addr")).as_deref(),
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder.build_store(pval, val).unwrap();
|
||||
Self::from_pointer_value(pval, llvm_usize, name)
|
||||
}
|
||||
|
||||
/// Creates an [`ListValue`] from a [`PointerValue`].
|
||||
#[must_use]
|
||||
pub fn from_pointer_value(
|
||||
ptr: PointerValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
|
||||
|
||||
ListValue { value: ptr, llvm_usize, name }
|
||||
}
|
||||
|
||||
fn items_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||
self.get_type().get_fields().items
|
||||
}
|
||||
|
||||
/// Stores the array of data elements `data` into this instance.
|
||||
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
|
||||
self.items_field().store(ctx, self.value, data, self.name);
|
||||
}
|
||||
|
||||
/// Convenience method for creating a new array storing data elements with the given element
|
||||
/// type `elem_ty` and `size`.
|
||||
///
|
||||
/// If `size` is [None], the size stored in the field of this instance is used instead. If
|
||||
/// `size` is resolved to `0` at runtime, `(T*) 0` will be assigned to `data`.
|
||||
pub fn create_data(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: BasicTypeEnum<'ctx>,
|
||||
size: Option<IntValue<'ctx>>,
|
||||
) {
|
||||
let size = size.unwrap_or_else(|| self.load_size(ctx, None));
|
||||
|
||||
let data = ctx
|
||||
.builder
|
||||
.build_select(
|
||||
ctx.builder
|
||||
.build_int_compare(IntPredicate::NE, size, self.llvm_usize.const_zero(), "")
|
||||
.unwrap(),
|
||||
ctx.builder.build_array_alloca(elem_ty, size, "").unwrap(),
|
||||
elem_ty.ptr_type(AddressSpace::default()).const_zero(),
|
||||
"",
|
||||
)
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap();
|
||||
self.store_data(ctx, data);
|
||||
}
|
||||
|
||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||
/// on the field.
|
||||
#[must_use]
|
||||
pub fn data(&self) -> ListDataProxy<'ctx, '_> {
|
||||
ListDataProxy(self)
|
||||
}
|
||||
|
||||
fn len_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
|
||||
self.get_type().get_fields().len
|
||||
}
|
||||
|
||||
/// Stores the `size` of this `list` into this instance.
|
||||
pub fn store_size(&self, ctx: &CodeGenContext<'ctx, '_>, size: IntValue<'ctx>) {
|
||||
debug_assert_eq!(size.get_type(), ctx.get_size_type());
|
||||
|
||||
self.len_field().store(ctx, self.value, size, self.name);
|
||||
}
|
||||
|
||||
/// Returns the size of this `list` as a value.
|
||||
pub fn load_size(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> IntValue<'ctx> {
|
||||
self.len_field().load(ctx, self.value, name)
|
||||
}
|
||||
|
||||
/// Returns an instance of [`ListValue`] with the `items` pointer cast to `i8*`.
|
||||
#[must_use]
|
||||
pub fn as_i8_list(&self, ctx: &CodeGenContext<'ctx, '_>) -> ListValue<'ctx> {
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_list_i8 = <Self as ProxyValue>::Type::new(ctx, &llvm_i8);
|
||||
|
||||
Self::from_pointer_value(
|
||||
ctx.builder.build_pointer_cast(self.value, llvm_list_i8.as_abi_type(), "").unwrap(),
|
||||
self.llvm_usize,
|
||||
self.name,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> {
|
||||
type ABI = PointerValue<'ctx>;
|
||||
type Base = PointerValue<'ctx>;
|
||||
type Type = ListType<'ctx>;
|
||||
|
||||
fn get_type(&self) -> Self::Type {
|
||||
ListType::from_pointer_type(self.as_base_value().get_type(), self.llvm_usize)
|
||||
}
|
||||
|
||||
fn as_base_value(&self) -> Self::Base {
|
||||
self.value
|
||||
}
|
||||
|
||||
fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI {
|
||||
self.as_base_value()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> StructProxyValue<'ctx> for ListValue<'ctx> {}
|
||||
|
||||
impl<'ctx> From<ListValue<'ctx>> for PointerValue<'ctx> {
|
||||
fn from(value: ListValue<'ctx>) -> Self {
|
||||
value.as_base_value()
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy type for accessing the `data` array of an `list` instance in LLVM.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ListDataProxy<'ctx, 'a>(&'a ListValue<'ctx>);
|
||||
|
||||
impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> {
|
||||
fn element_type<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
_: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> AnyTypeEnum<'ctx> {
|
||||
self.0.value.get_type().get_element_type()
|
||||
}
|
||||
|
||||
fn base_ptr<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> PointerValue<'ctx> {
|
||||
self.0.items_field().load(ctx, self.0.value, self.0.name)
|
||||
}
|
||||
|
||||
fn size<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
_: &G,
|
||||
) -> IntValue<'ctx> {
|
||||
self.0.load_size(ctx, None)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
|
||||
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &G,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
|
||||
|
||||
unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn ptr_offset<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut G,
|
||||
idx: &IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
debug_assert_eq!(idx.get_type(), ctx.get_size_type());
|
||||
|
||||
let size = self.size(ctx, generator);
|
||||
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
in_range,
|
||||
"0:IndexError",
|
||||
"list index out of range",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ListDataProxy<'ctx, '_> {}
|
||||
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ListDataProxy<'ctx, '_> {}
|
44
nac3core/src/codegen/values/mod.rs
Normal file
44
nac3core/src/codegen/values/mod.rs
Normal file
@ -0,0 +1,44 @@
|
||||
use inkwell::{types::IntType, values::BasicValue};
|
||||
|
||||
use super::{types::ProxyType, CodeGenContext};
|
||||
pub use array::*;
|
||||
pub use list::*;
|
||||
pub use range::*;
|
||||
pub use tuple::*;
|
||||
|
||||
mod array;
|
||||
mod list;
|
||||
pub mod ndarray;
|
||||
mod range;
|
||||
pub mod structure;
|
||||
mod tuple;
|
||||
pub mod utils;
|
||||
|
||||
/// A LLVM type that is used to represent a non-primitive value in NAC3.
|
||||
pub trait ProxyValue<'ctx>: Into<Self::Base> {
|
||||
/// The ABI type of LLVM values represented by this instance.
|
||||
type ABI: BasicValue<'ctx>;
|
||||
|
||||
/// The type of LLVM values represented by this instance.
|
||||
type Base: BasicValue<'ctx>;
|
||||
|
||||
/// The type of this value.
|
||||
type Type: ProxyType<'ctx, Value = Self>;
|
||||
|
||||
/// Checks whether `value` can be represented by this [`ProxyValue`].
|
||||
fn is_instance(value: impl BasicValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> {
|
||||
Self::Type::is_representable(value.as_basic_value_enum().get_type(), llvm_usize)
|
||||
}
|
||||
|
||||
/// Returns the [type][ProxyType] of this value.
|
||||
fn get_type(&self) -> Self::Type;
|
||||
|
||||
/// Returns the [base value][Self::Base] of this proxy.
|
||||
fn as_base_value(&self) -> Self::Base;
|
||||
|
||||
/// Returns this proxy as its ABI value, i.e. the expected value representation if a value
|
||||
/// represented by this [`ProxyValue`] is being passed into or returned from a function.
|
||||
///
|
||||
/// See [`CodeGenContext::get_llvm_abi_type`].
|
||||
fn as_abi_value(&self, ctx: &CodeGenContext<'ctx, '_>) -> Self::ABI;
|
||||
}
|
262
nac3core/src/codegen/values/ndarray/broadcast.rs
Normal file
262
nac3core/src/codegen/values/ndarray/broadcast.rs
Normal file
@ -0,0 +1,262 @@
|
||||
use inkwell::{
|
||||
types::IntType,
|
||||
values::{IntValue, PointerValue, StructValue},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::{
|
||||
irrt,
|
||||
types::{
|
||||
ndarray::{NDArrayType, ShapeEntryType},
|
||||
structure::{StructField, StructProxyType},
|
||||
ProxyType,
|
||||
},
|
||||
values::{
|
||||
ndarray::NDArrayValue, structure::StructProxyValue, ArrayLikeIndexer, ArrayLikeValue,
|
||||
ArraySliceValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
||||
TypedArrayLikeMutator,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ShapeEntryValue<'ctx> {
|
||||
value: PointerValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
}
|
||||
|
||||
impl<'ctx> ShapeEntryValue<'ctx> {
|
||||
/// Creates an [`ShapeEntryValue`] from a [`StructValue`].
|
||||
#[must_use]
|
||||
pub fn from_struct_value<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
val: StructValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
let pval = generator
|
||||
.gen_var_alloc(
|
||||
ctx,
|
||||
val.get_type().into(),
|
||||
name.map(|name| format!("{name}.addr")).as_deref(),
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder.build_store(pval, val).unwrap();
|
||||
Self::from_pointer_value(pval, llvm_usize, name)
|
||||
}
|
||||
|
||||
/// Creates an [`ShapeEntryValue`] from a [`PointerValue`].
|
||||
#[must_use]
|
||||
pub fn from_pointer_value(
|
||||
ptr: PointerValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
|
||||
|
||||
Self { value: ptr, llvm_usize, name }
|
||||
}
|
||||
|
||||
fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
|
||||
self.get_type().get_fields().ndims
|
||||
}
|
||||
|
||||
/// Stores the number of dimensions into this value.
|
||||
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
|
||||
self.ndims_field().store(ctx, self.value, value, self.name);
|
||||
}
|
||||
|
||||
fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||
self.get_type().get_fields().shape
|
||||
}
|
||||
|
||||
/// Stores the shape into this value.
|
||||
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
|
||||
self.shape_field().store(ctx, self.value, value, self.name);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyValue<'ctx> for ShapeEntryValue<'ctx> {
|
||||
type ABI = PointerValue<'ctx>;
|
||||
type Base = PointerValue<'ctx>;
|
||||
type Type = ShapeEntryType<'ctx>;
|
||||
|
||||
fn get_type(&self) -> Self::Type {
|
||||
Self::Type::from_pointer_type(self.value.get_type(), self.llvm_usize)
|
||||
}
|
||||
|
||||
fn as_base_value(&self) -> Self::Base {
|
||||
self.value
|
||||
}
|
||||
|
||||
fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI {
|
||||
self.as_base_value()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> StructProxyValue<'ctx> for ShapeEntryValue<'ctx> {}
|
||||
|
||||
impl<'ctx> From<ShapeEntryValue<'ctx>> for PointerValue<'ctx> {
|
||||
fn from(value: ShapeEntryValue<'ctx>) -> Self {
|
||||
value.as_base_value()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayValue<'ctx> {
|
||||
/// Create a broadcast view on this ndarray with a target shape.
|
||||
///
|
||||
/// The input shape will be checked to make sure that it contains no negative values.
|
||||
///
|
||||
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
|
||||
/// The caller has to figure this out for this function.
|
||||
/// * `target_shape` - An array pointer pointing to the target shape.
|
||||
#[must_use]
|
||||
pub fn broadcast_to<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
target_ndims: u64,
|
||||
target_shape: &impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
|
||||
) -> Self {
|
||||
assert!(self.ndims <= target_ndims);
|
||||
assert_eq!(target_shape.element_type(ctx, generator), self.llvm_usize.into());
|
||||
|
||||
let broadcast_ndarray = NDArrayType::new(ctx, self.dtype, target_ndims)
|
||||
.construct_uninitialized(generator, ctx, None);
|
||||
broadcast_ndarray.copy_shape_from_array(
|
||||
generator,
|
||||
ctx,
|
||||
target_shape.base_ptr(ctx, generator),
|
||||
);
|
||||
|
||||
irrt::ndarray::call_nac3_ndarray_broadcast_to(ctx, *self, broadcast_ndarray);
|
||||
broadcast_ndarray
|
||||
}
|
||||
}
|
||||
|
||||
/// A result produced by [`broadcast_all_ndarrays`]
|
||||
#[derive(Clone)]
|
||||
pub struct BroadcastAllResult<'ctx, G: CodeGenerator + ?Sized> {
|
||||
/// The statically known `ndims` of the broadcast result.
|
||||
pub ndims: u64,
|
||||
|
||||
/// The broadcasting shape.
|
||||
pub shape: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>,
|
||||
|
||||
/// Broadcasted views on the inputs.
|
||||
///
|
||||
/// All of them will have `shape` [`BroadcastAllResult::shape`] and
|
||||
/// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector
|
||||
/// is the same as the input.
|
||||
pub ndarrays: Vec<NDArrayValue<'ctx>>,
|
||||
}
|
||||
|
||||
/// Helper function to call [`irrt::ndarray::call_nac3_ndarray_broadcast_shapes`].
|
||||
fn broadcast_shapes<'ctx, G, Shape>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
in_shape_entries: &[(ArraySliceValue<'ctx>, u64)], // (shape, shape's length/ndims)
|
||||
broadcast_ndims: u64,
|
||||
broadcast_shape: &Shape,
|
||||
) where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Shape: TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>
|
||||
+ TypedArrayLikeMutator<'ctx, G, IntValue<'ctx>>,
|
||||
{
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_shape_ty = ShapeEntryType::new(ctx);
|
||||
|
||||
assert!(in_shape_entries
|
||||
.iter()
|
||||
.all(|entry| entry.0.element_type(ctx, generator) == llvm_usize.into()));
|
||||
assert_eq!(broadcast_shape.element_type(ctx, generator), llvm_usize.into());
|
||||
|
||||
// Prepare input shape entries to be passed to `call_nac3_ndarray_broadcast_shapes`.
|
||||
let num_shape_entries =
|
||||
llvm_usize.const_int(u64::try_from(in_shape_entries.len()).unwrap(), false);
|
||||
let shape_entries = llvm_shape_ty.array_alloca(ctx, num_shape_entries, None);
|
||||
for (i, (in_shape, in_ndims)) in in_shape_entries.iter().enumerate() {
|
||||
let pshape_entry = unsafe {
|
||||
shape_entries.ptr_offset_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(i as u64, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
let shape_entry = llvm_shape_ty.map_pointer_value(pshape_entry, None);
|
||||
|
||||
let in_ndims = llvm_usize.const_int(*in_ndims, false);
|
||||
shape_entry.store_ndims(ctx, in_ndims);
|
||||
|
||||
shape_entry.store_shape(ctx, in_shape.base_ptr(ctx, generator));
|
||||
}
|
||||
|
||||
let broadcast_ndims = llvm_usize.const_int(broadcast_ndims, false);
|
||||
irrt::ndarray::call_nac3_ndarray_broadcast_shapes(
|
||||
generator,
|
||||
ctx,
|
||||
num_shape_entries,
|
||||
shape_entries,
|
||||
broadcast_ndims,
|
||||
broadcast_shape,
|
||||
);
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayType<'ctx> {
|
||||
/// Broadcast all ndarrays according to
|
||||
/// [`np.broadcast()`](https://numpy.org/doc/stable/reference/generated/numpy.broadcast.html)
|
||||
/// and return a [`BroadcastAllResult`] containing all the information of the result of the
|
||||
/// broadcast operation.
|
||||
pub fn broadcast<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarrays: &[NDArrayValue<'ctx>],
|
||||
) -> BroadcastAllResult<'ctx, G> {
|
||||
assert!(!ndarrays.is_empty());
|
||||
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
|
||||
// Infer the broadcast output ndims.
|
||||
let broadcast_ndims_int =
|
||||
ndarrays.iter().map(|ndarray| ndarray.get_type().ndims()).max().unwrap();
|
||||
assert!(self.ndims() >= broadcast_ndims_int);
|
||||
|
||||
let broadcast_ndims = llvm_usize.const_int(broadcast_ndims_int, false);
|
||||
let broadcast_shape = ArraySliceValue::from_ptr_val(
|
||||
ctx.builder.build_array_alloca(llvm_usize, broadcast_ndims, "").unwrap(),
|
||||
broadcast_ndims,
|
||||
None,
|
||||
);
|
||||
let broadcast_shape = TypedArrayLikeAdapter::from(
|
||||
broadcast_shape,
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
|
||||
let shape_entries = ndarrays
|
||||
.iter()
|
||||
.map(|ndarray| {
|
||||
(ndarray.shape().as_slice_value(ctx, generator), ndarray.get_type().ndims())
|
||||
})
|
||||
.collect_vec();
|
||||
broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, &broadcast_shape);
|
||||
|
||||
// Broadcast all the inputs to shape `dst_shape`.
|
||||
let broadcast_ndarrays = ndarrays
|
||||
.iter()
|
||||
.map(|ndarray| {
|
||||
ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, &broadcast_shape)
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
BroadcastAllResult {
|
||||
ndims: broadcast_ndims_int,
|
||||
shape: broadcast_shape,
|
||||
ndarrays: broadcast_ndarrays,
|
||||
}
|
||||
}
|
||||
}
|
223
nac3core/src/codegen/values/ndarray/contiguous.rs
Normal file
223
nac3core/src/codegen/values/ndarray/contiguous.rs
Normal file
@ -0,0 +1,223 @@
|
||||
use inkwell::{
|
||||
types::{BasicType, BasicTypeEnum, IntType},
|
||||
values::{IntValue, PointerValue, StructValue},
|
||||
AddressSpace,
|
||||
};
|
||||
|
||||
use super::NDArrayValue;
|
||||
use crate::codegen::{
|
||||
stmt::gen_if_callback,
|
||||
types::{
|
||||
ndarray::{ContiguousNDArrayType, NDArrayType},
|
||||
structure::{StructField, StructProxyType},
|
||||
},
|
||||
values::{structure::StructProxyValue, ArrayLikeValue, ProxyValue},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ContiguousNDArrayValue<'ctx> {
|
||||
value: PointerValue<'ctx>,
|
||||
item: BasicTypeEnum<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
}
|
||||
|
||||
impl<'ctx> ContiguousNDArrayValue<'ctx> {
|
||||
/// Creates an [`ContiguousNDArrayValue`] from a [`StructValue`].
|
||||
#[must_use]
|
||||
pub fn from_struct_value<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
val: StructValue<'ctx>,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
let pval = generator
|
||||
.gen_var_alloc(
|
||||
ctx,
|
||||
val.get_type().into(),
|
||||
name.map(|name| format!("{name}.addr")).as_deref(),
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder.build_store(pval, val).unwrap();
|
||||
Self::from_pointer_value(pval, dtype, llvm_usize, name)
|
||||
}
|
||||
|
||||
/// Creates an [`ContiguousNDArrayValue`] from a [`PointerValue`].
|
||||
#[must_use]
|
||||
pub fn from_pointer_value(
|
||||
ptr: PointerValue<'ctx>,
|
||||
dtype: BasicTypeEnum<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
|
||||
|
||||
Self { value: ptr, item: dtype, llvm_usize, name }
|
||||
}
|
||||
|
||||
fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
|
||||
self.get_type().get_fields().ndims
|
||||
}
|
||||
|
||||
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
|
||||
self.ndims_field().store(ctx, self.as_abi_value(ctx), value, self.name);
|
||||
}
|
||||
|
||||
fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||
self.get_type().get_fields().shape
|
||||
}
|
||||
|
||||
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
|
||||
self.shape_field().store(ctx, self.as_abi_value(ctx), value, self.name);
|
||||
}
|
||||
|
||||
pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
self.shape_field().load(ctx, self.value, self.name)
|
||||
}
|
||||
|
||||
fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||
self.get_type().get_fields().data
|
||||
}
|
||||
|
||||
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
|
||||
self.data_field().store(ctx, self.as_abi_value(ctx), value, self.name);
|
||||
}
|
||||
|
||||
pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
self.data_field().load(ctx, self.value, self.name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> {
|
||||
type ABI = PointerValue<'ctx>;
|
||||
type Base = PointerValue<'ctx>;
|
||||
type Type = ContiguousNDArrayType<'ctx>;
|
||||
|
||||
fn get_type(&self) -> Self::Type {
|
||||
<Self as ProxyValue<'ctx>>::Type::from_pointer_type(
|
||||
self.as_base_value().get_type(),
|
||||
self.item,
|
||||
self.llvm_usize,
|
||||
)
|
||||
}
|
||||
|
||||
fn as_base_value(&self) -> Self::Base {
|
||||
self.value
|
||||
}
|
||||
|
||||
fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI {
|
||||
self.as_base_value()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> StructProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> {}
|
||||
|
||||
impl<'ctx> From<ContiguousNDArrayValue<'ctx>> for PointerValue<'ctx> {
|
||||
fn from(value: ContiguousNDArrayValue<'ctx>) -> Self {
|
||||
value.as_base_value()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayValue<'ctx> {
|
||||
/// Create a [`ContiguousNDArrayValue`] from the contents of this ndarray.
|
||||
///
|
||||
/// This function may or may not be expensive depending on if this ndarray has contiguous data.
|
||||
///
|
||||
/// If this ndarray is not C-contiguous, this function will allocate memory on the stack for the
|
||||
/// `data` field of the returned [`ContiguousNDArrayValue`] and copy contents of this ndarray to
|
||||
/// there.
|
||||
///
|
||||
/// If this ndarray is C-contiguous, contents of this ndarray will not be copied. The created
|
||||
/// [`ContiguousNDArrayValue`] will share memory with this ndarray.
|
||||
pub fn make_contiguous_ndarray<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> ContiguousNDArrayValue<'ctx> {
|
||||
let result =
|
||||
ContiguousNDArrayType::new(ctx, &self.dtype).alloca_var(generator, ctx, self.name);
|
||||
|
||||
// Set ndims and shape.
|
||||
let ndims = self.llvm_usize.const_int(self.ndims, false);
|
||||
result.store_ndims(ctx, ndims);
|
||||
|
||||
let shape = self.shape();
|
||||
result.store_shape(ctx, shape.base_ptr(ctx, generator));
|
||||
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| Ok(self.is_c_contiguous(ctx)),
|
||||
|_, ctx| {
|
||||
// This ndarray is contiguous.
|
||||
let data = self.data_field().load(ctx, self.as_abi_value(ctx), self.name);
|
||||
let data = ctx
|
||||
.builder
|
||||
.build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "")
|
||||
.unwrap();
|
||||
result.store_data(ctx, data);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|generator, ctx| {
|
||||
// This ndarray is not contiguous. Do a full-copy on `data`. `make_copy` produces an
|
||||
// ndarray with contiguous `data`.
|
||||
let copied_ndarray = self.make_copy(generator, ctx);
|
||||
let data = copied_ndarray.data().base_ptr(ctx, generator);
|
||||
let data = ctx
|
||||
.builder
|
||||
.build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "")
|
||||
.unwrap();
|
||||
result.store_data(ctx, data);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Create an [`NDArrayValue`] from a [`ContiguousNDArrayValue`].
|
||||
///
|
||||
/// The operation is cheap. The newly created [`NDArrayValue`] will share the same memory as the
|
||||
/// [`ContiguousNDArrayValue`].
|
||||
///
|
||||
/// `ndims` has to be provided as [`NDArrayValue`] requires a statically known `ndims` value,
|
||||
/// despite the fact that the information should be contained within the
|
||||
/// [`ContiguousNDArrayValue`].
|
||||
pub fn from_contiguous_ndarray<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
carray: ContiguousNDArrayValue<'ctx>,
|
||||
ndims: u64,
|
||||
) -> Self {
|
||||
// TODO: Debug assert `ndims == carray.ndims` to catch bugs.
|
||||
|
||||
// Allocate the resulting ndarray.
|
||||
let ndarray = NDArrayType::new(ctx, carray.item, ndims).construct_uninitialized(
|
||||
generator,
|
||||
ctx,
|
||||
carray.name,
|
||||
);
|
||||
|
||||
// Copy shape and update strides
|
||||
let shape = carray.load_shape(ctx);
|
||||
ndarray.copy_shape_from_array(generator, ctx, shape);
|
||||
ndarray.set_strides_contiguous(ctx);
|
||||
|
||||
// Share data
|
||||
let data = carray.load_data(ctx);
|
||||
ndarray.store_data(
|
||||
ctx,
|
||||
ctx.builder
|
||||
.build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
ndarray
|
||||
}
|
||||
}
|
101
nac3core/src/codegen/values/ndarray/fold.rs
Normal file
101
nac3core/src/codegen/values/ndarray/fold.rs
Normal file
@ -0,0 +1,101 @@
|
||||
use inkwell::values::{BasicValue, BasicValueEnum};
|
||||
|
||||
use super::{NDArrayValue, NDIterValue, ScalarOrNDArray};
|
||||
use crate::codegen::{
|
||||
stmt::{gen_for_callback, BreakContinueHooks},
|
||||
types::ndarray::NDIterType,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
impl<'ctx> NDArrayValue<'ctx> {
|
||||
/// Folds the elements of this ndarray into an accumulator value by applying `f`, returning the
|
||||
/// final value.
|
||||
///
|
||||
/// `f` has access to [`BreakContinueHooks`] to short-circuit the `fold` operation, an instance
|
||||
/// of `V` representing the current accumulated value, and an [`NDIterValue`] to get the
|
||||
/// properties of the current iterated element.
|
||||
pub fn fold<'a, G, V, F>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
init: V,
|
||||
f: F,
|
||||
) -> Result<V, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
V: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>,
|
||||
<V as TryFrom<BasicValueEnum<'ctx>>>::Error: std::fmt::Debug,
|
||||
F: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
V,
|
||||
NDIterValue<'ctx>,
|
||||
) -> Result<V, String>,
|
||||
{
|
||||
let acc_ptr =
|
||||
generator.gen_var_alloc(ctx, init.as_basic_value_enum().get_type(), None).unwrap();
|
||||
ctx.builder.build_store(acc_ptr, init).unwrap();
|
||||
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
Some("ndarray_fold"),
|
||||
|generator, ctx| Ok(NDIterType::new(ctx).construct(generator, ctx, *self)),
|
||||
|_, ctx, nditer| Ok(nditer.has_element(ctx)),
|
||||
|generator, ctx, hooks, nditer| {
|
||||
let acc = V::try_from(ctx.builder.build_load(acc_ptr, "").unwrap()).unwrap();
|
||||
let acc = f(generator, ctx, hooks, acc, nditer)?;
|
||||
ctx.builder.build_store(acc_ptr, acc).unwrap();
|
||||
Ok(())
|
||||
},
|
||||
|_, ctx, nditer| {
|
||||
nditer.next(ctx);
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
|
||||
let acc = ctx.builder.build_load(acc_ptr, "").unwrap();
|
||||
Ok(V::try_from(acc).unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
/// See [`NDArrayValue::fold`].
|
||||
///
|
||||
/// The primary differences between this function and `NDArrayValue::fold` are:
|
||||
///
|
||||
/// - The 3rd parameter of `f` is an `Option` of hooks, since `break`/`continue` hooks are not
|
||||
/// available if this instance represents a scalar value.
|
||||
/// - The 5th parameter of `f` is a [`BasicValueEnum`], since no [iterator][`NDIterValue`] will
|
||||
/// be created if this instance represents a scalar value.
|
||||
pub fn fold<'a, G, V, F>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
init: V,
|
||||
f: F,
|
||||
) -> Result<V, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
V: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>,
|
||||
<V as TryFrom<BasicValueEnum<'ctx>>>::Error: std::fmt::Debug,
|
||||
F: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Option<&BreakContinueHooks<'ctx>>,
|
||||
V,
|
||||
BasicValueEnum<'ctx>,
|
||||
) -> Result<V, String>,
|
||||
{
|
||||
match self {
|
||||
ScalarOrNDArray::Scalar(v) => f(generator, ctx, None, init, *v),
|
||||
ScalarOrNDArray::NDArray(v) => {
|
||||
v.fold(generator, ctx, init, |generator, ctx, hooks, acc, nditer| {
|
||||
let elem = nditer.get_scalar(ctx);
|
||||
f(generator, ctx, Some(&hooks), acc, elem)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
278
nac3core/src/codegen/values/ndarray/indexing.rs
Normal file
278
nac3core/src/codegen/values/ndarray/indexing.rs
Normal file
@ -0,0 +1,278 @@
|
||||
use inkwell::{
|
||||
types::IntType,
|
||||
values::{IntValue, PointerValue, StructValue},
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use nac3parser::ast::{Expr, ExprKind};
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt,
|
||||
types::{
|
||||
ndarray::{NDArrayType, NDIndexType},
|
||||
structure::{StructField, StructProxyType},
|
||||
utils::SliceType,
|
||||
},
|
||||
values::{
|
||||
ndarray::NDArrayValue, structure::StructProxyValue, utils::RustSlice, ProxyValue,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
|
||||
/// An IRRT representation of an ndarray subscript index.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct NDIndexValue<'ctx> {
|
||||
value: PointerValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
}
|
||||
|
||||
impl<'ctx> NDIndexValue<'ctx> {
|
||||
/// Creates an [`NDIndexValue`] from a [`StructValue`].
|
||||
#[must_use]
|
||||
pub fn from_struct_value<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
val: StructValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
let pval = generator
|
||||
.gen_var_alloc(
|
||||
ctx,
|
||||
val.get_type().into(),
|
||||
name.map(|name| format!("{name}.addr")).as_deref(),
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder.build_store(pval, val).unwrap();
|
||||
Self::from_pointer_value(pval, llvm_usize, name)
|
||||
}
|
||||
|
||||
/// Creates an [`NDIndexValue`] from a [`PointerValue`].
|
||||
#[must_use]
|
||||
pub fn from_pointer_value(
|
||||
ptr: PointerValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
|
||||
|
||||
Self { value: ptr, llvm_usize, name }
|
||||
}
|
||||
|
||||
fn type_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
|
||||
self.get_type().get_fields().type_
|
||||
}
|
||||
|
||||
pub fn load_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
self.type_field().load(ctx, self.value, self.name)
|
||||
}
|
||||
|
||||
pub fn store_type(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
|
||||
self.type_field().store(ctx, self.value, value, self.name);
|
||||
}
|
||||
|
||||
fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||
self.get_type().get_fields().data
|
||||
}
|
||||
|
||||
pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
self.data_field().load(ctx, self.value, self.name)
|
||||
}
|
||||
|
||||
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
|
||||
self.data_field().store(ctx, self.value, value, self.name);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> {
|
||||
type ABI = PointerValue<'ctx>;
|
||||
type Base = PointerValue<'ctx>;
|
||||
type Type = NDIndexType<'ctx>;
|
||||
|
||||
fn get_type(&self) -> Self::Type {
|
||||
Self::Type::from_pointer_type(self.value.get_type(), self.llvm_usize)
|
||||
}
|
||||
|
||||
fn as_base_value(&self) -> Self::Base {
|
||||
self.value
|
||||
}
|
||||
|
||||
fn as_abi_value(&self, _: &CodeGenContext<'ctx, '_>) -> Self::ABI {
|
||||
self.as_base_value()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> StructProxyValue<'ctx> for NDIndexValue<'ctx> {}
|
||||
|
||||
impl<'ctx> From<NDIndexValue<'ctx>> for PointerValue<'ctx> {
|
||||
fn from(value: NDIndexValue<'ctx>) -> Self {
|
||||
value.as_base_value()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayValue<'ctx> {
|
||||
/// Get the expected `ndims` after indexing with `indices`.
|
||||
#[must_use]
|
||||
fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> u64 {
|
||||
let mut ndims = self.ndims;
|
||||
|
||||
for index in indices {
|
||||
match index {
|
||||
RustNDIndex::SingleElement(_) => {
|
||||
ndims -= 1; // Single elements decrements ndims
|
||||
}
|
||||
RustNDIndex::NewAxis => {
|
||||
ndims += 1; // `np.newaxis` / `none` adds a new axis
|
||||
}
|
||||
RustNDIndex::Ellipsis | RustNDIndex::Slice(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
ndims
|
||||
}
|
||||
|
||||
/// Index into the ndarray, and return a newly-allocated view on this ndarray.
|
||||
///
|
||||
/// This function behaves like NumPy's ndarray indexing, but if the indices index
|
||||
/// into a single element, an unsized ndarray is returned.
|
||||
#[must_use]
|
||||
pub fn index<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
indices: &[RustNDIndex<'ctx>],
|
||||
) -> Self {
|
||||
let dst_ndims = self.deduce_ndims_after_indexing_with(indices);
|
||||
let dst_ndarray = NDArrayType::new(ctx, self.dtype, dst_ndims)
|
||||
.construct_uninitialized(generator, ctx, None);
|
||||
|
||||
let indices = NDIndexType::new(ctx).construct_ndindices(generator, ctx, indices);
|
||||
irrt::ndarray::call_nac3_ndarray_index(generator, ctx, indices, *self, dst_ndarray);
|
||||
|
||||
dst_ndarray
|
||||
}
|
||||
}
|
||||
|
||||
/// A convenience enum representing a [`NDIndexValue`].
|
||||
// TODO: Rename to CTConstNDIndex
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RustNDIndex<'ctx> {
|
||||
SingleElement(IntValue<'ctx>),
|
||||
Slice(RustSlice<'ctx>),
|
||||
NewAxis,
|
||||
Ellipsis,
|
||||
}
|
||||
|
||||
impl<'ctx> RustNDIndex<'ctx> {
|
||||
/// Generate LLVM code to transform an ndarray subscript expression to
|
||||
/// its list of [`RustNDIndex`]
|
||||
///
|
||||
/// i.e.,
|
||||
/// ```python
|
||||
/// my_ndarray[::3, 1, :2:]
|
||||
/// ^^^^^^^^^^^ Then these into a three `RustNDIndex`es
|
||||
/// ```
|
||||
pub fn from_subscript_expr<G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
subscript: &Expr<Option<Type>>,
|
||||
) -> Result<Vec<RustNDIndex<'ctx>>, String> {
|
||||
// Annoying notes about `slice`
|
||||
// - `my_array[5]`
|
||||
// - slice is a `Constant`
|
||||
// - `my_array[:5]`
|
||||
// - slice is a `Slice`
|
||||
// - `my_array[:]`
|
||||
// - slice is a `Slice`, but lower upper step would all be `Option::None`
|
||||
// - `my_array[:, :]`
|
||||
// - slice is now a `Tuple` of two `Slice`-s
|
||||
//
|
||||
// In summary:
|
||||
// - when there is a comma "," within [], `slice` will be a `Tuple` of the entries.
|
||||
// - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself.
|
||||
//
|
||||
// So we first "flatten" out the slice expression
|
||||
let index_exprs = match &subscript.node {
|
||||
ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(),
|
||||
_ => vec![subscript],
|
||||
};
|
||||
|
||||
// Process all index expressions
|
||||
let mut rust_ndindices: Vec<RustNDIndex> = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here.
|
||||
for index_expr in index_exprs {
|
||||
// NOTE: Currently nac3core's slices do not have an object representation,
|
||||
// so the code/implementation looks awkward - we have to do pattern matching on the expression
|
||||
let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node {
|
||||
// Handle slices
|
||||
let slice = RustSlice::from_slice_expr(generator, ctx, lower, upper, step)?;
|
||||
RustNDIndex::Slice(slice)
|
||||
} else {
|
||||
// Treat and handle everything else as a single element index.
|
||||
let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
|
||||
ctx,
|
||||
generator,
|
||||
ctx.primitives.int32, // Must be int32, this checks for illegal values
|
||||
)?;
|
||||
let index = index.into_int_value();
|
||||
|
||||
RustNDIndex::SingleElement(index)
|
||||
};
|
||||
rust_ndindices.push(ndindex);
|
||||
}
|
||||
Ok(rust_ndindices)
|
||||
}
|
||||
|
||||
/// Get the value to set `NDIndex::type` for this variant.
|
||||
#[must_use]
|
||||
pub fn get_type_id(&self) -> u64 {
|
||||
// Defined in IRRT, must be in sync
|
||||
match self {
|
||||
RustNDIndex::SingleElement(_) => 0,
|
||||
RustNDIndex::Slice(_) => 1,
|
||||
RustNDIndex::NewAxis => 2,
|
||||
RustNDIndex::Ellipsis => 3,
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize this [`RustNDIndex`] by writing it into an LLVM [`NDIndexValue`].
|
||||
pub fn write_to_ndindex<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dst_ndindex: NDIndexValue<'ctx>,
|
||||
) {
|
||||
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||
|
||||
// Set `dst_ndindex.type`
|
||||
dst_ndindex.store_type(ctx, ctx.ctx.i8_type().const_int(self.get_type_id(), false));
|
||||
|
||||
// Set `dst_ndindex_ptr->data`
|
||||
match self {
|
||||
RustNDIndex::SingleElement(in_index) => {
|
||||
let index_ptr = ctx.builder.build_alloca(ctx.ctx.i32_type(), "").unwrap();
|
||||
ctx.builder.build_store(index_ptr, *in_index).unwrap();
|
||||
|
||||
dst_ndindex.store_data(
|
||||
ctx,
|
||||
ctx.builder.build_pointer_cast(index_ptr, llvm_pi8, "").unwrap(),
|
||||
);
|
||||
}
|
||||
RustNDIndex::Slice(in_rust_slice) => {
|
||||
let user_slice_ptr =
|
||||
SliceType::new(ctx, ctx.ctx.i32_type()).alloca_var(generator, ctx, None);
|
||||
in_rust_slice.write_to_slice(ctx, user_slice_ptr);
|
||||
|
||||
dst_ndindex.store_data(
|
||||
ctx,
|
||||
ctx.builder.build_pointer_cast(user_slice_ptr.into(), llvm_pi8, "").unwrap(),
|
||||
);
|
||||
}
|
||||
RustNDIndex::NewAxis | RustNDIndex::Ellipsis => {}
|
||||
}
|
||||
}
|
||||
}
|
69
nac3core/src/codegen/values/ndarray/map.rs
Normal file
69
nac3core/src/codegen/values/ndarray/map.rs
Normal file
@ -0,0 +1,69 @@
|
||||
use inkwell::{types::BasicTypeEnum, values::BasicValueEnum};
|
||||
|
||||
use crate::codegen::{
|
||||
values::{
|
||||
ndarray::{NDArrayOut, NDArrayValue, ScalarOrNDArray},
|
||||
ProxyValue,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
impl<'ctx> NDArrayValue<'ctx> {
|
||||
/// Map through this ndarray with an elementwise function.
|
||||
pub fn map<'a, G, Mapping>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
out: NDArrayOut<'ctx>,
|
||||
mapping: Mapping,
|
||||
) -> Result<Self, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Mapping: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BasicValueEnum<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
self.get_type().broadcast_starmap(
|
||||
generator,
|
||||
ctx,
|
||||
&[*self],
|
||||
out,
|
||||
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
/// Map through this [`ScalarOrNDArray`] with an elementwise function.
|
||||
///
|
||||
/// If this is a scalar, `mapping` will directly act on the scalar. This function will return a
|
||||
/// [`ScalarOrNDArray::Scalar`] of that result.
|
||||
///
|
||||
/// If this is an ndarray, `mapping` will be applied to the elements of the ndarray. A new
|
||||
/// ndarray of the results will be created and returned as a [`ScalarOrNDArray::NDArray`].
|
||||
pub fn map<'a, G, Mapping>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ret_dtype: BasicTypeEnum<'ctx>,
|
||||
mapping: Mapping,
|
||||
) -> Result<ScalarOrNDArray<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Mapping: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BasicValueEnum<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
ScalarOrNDArray::broadcasting_starmap(
|
||||
generator,
|
||||
ctx,
|
||||
&[*self],
|
||||
ret_dtype,
|
||||
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
|
||||
)
|
||||
}
|
||||
}
|
323
nac3core/src/codegen/values/ndarray/matmul.rs
Normal file
323
nac3core/src/codegen/values/ndarray/matmul.rs
Normal file
@ -0,0 +1,323 @@
|
||||
use std::cmp::max;
|
||||
|
||||
use nac3parser::ast::Operator;
|
||||
|
||||
use super::{NDArrayOut, NDArrayValue, RustNDIndex};
|
||||
use crate::{
|
||||
codegen::{
|
||||
expr::gen_binop_expr_with_values,
|
||||
irrt,
|
||||
stmt::gen_for_callback_incrementing,
|
||||
types::ndarray::NDArrayType,
|
||||
values::{
|
||||
ArrayLikeValue, ArraySliceValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
toplevel::helper::arraylike_flatten_element_type,
|
||||
typecheck::{magic_methods::Binop, typedef::Type},
|
||||
};
|
||||
|
||||
/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`.
|
||||
///
|
||||
/// `dst_dtype` defines the dtype of the returned ndarray.
|
||||
fn matmul_at_least_2d<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
dst_dtype: Type,
|
||||
(in_a_ty, in_a): (Type, NDArrayValue<'ctx>),
|
||||
(in_b_ty, in_b): (Type, NDArrayValue<'ctx>),
|
||||
) -> NDArrayValue<'ctx> {
|
||||
assert!(in_a.ndims >= 2, "in_a (which is {}) must be >= 2", in_a.ndims);
|
||||
assert!(in_b.ndims >= 2, "in_b (which is {}) must be >= 2", in_b.ndims);
|
||||
|
||||
let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_a_ty);
|
||||
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, in_b_ty);
|
||||
|
||||
let llvm_usize = ctx.get_size_type();
|
||||
let llvm_dst_dtype = ctx.get_llvm_type(generator, dst_dtype);
|
||||
|
||||
// Deduce ndims of the result of matmul.
|
||||
let ndims_int = max(in_a.ndims, in_b.ndims);
|
||||
let ndims = llvm_usize.const_int(ndims_int, false);
|
||||
|
||||
// Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the
|
||||
// destination ndarray to store the result of matmul.
|
||||
let (lhs, rhs, dst) = {
|
||||
let in_lhs_ndims = llvm_usize.const_int(in_a.ndims, false);
|
||||
let in_lhs_shape = TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(
|
||||
in_a.shape().base_ptr(ctx, generator),
|
||||
in_lhs_ndims,
|
||||
None,
|
||||
),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
let in_rhs_ndims = llvm_usize.const_int(in_b.ndims, false);
|
||||
let in_rhs_shape = TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(
|
||||
in_b.shape().base_ptr(ctx, generator),
|
||||
in_rhs_ndims,
|
||||
None,
|
||||
),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
let lhs_shape = TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(
|
||||
ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(),
|
||||
ndims,
|
||||
None,
|
||||
),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
let rhs_shape = TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(
|
||||
ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(),
|
||||
ndims,
|
||||
None,
|
||||
),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
let dst_shape = TypedArrayLikeAdapter::from(
|
||||
ArraySliceValue::from_ptr_val(
|
||||
ctx.builder.build_array_alloca(llvm_usize, ndims, "").unwrap(),
|
||||
ndims,
|
||||
None,
|
||||
),
|
||||
|_, _, val| val.into_int_value(),
|
||||
|_, _, val| val.into(),
|
||||
);
|
||||
|
||||
// Matmul dimension compatibility is checked here.
|
||||
irrt::ndarray::call_nac3_ndarray_matmul_calculate_shapes(
|
||||
generator,
|
||||
ctx,
|
||||
&in_lhs_shape,
|
||||
&in_rhs_shape,
|
||||
ndims,
|
||||
&lhs_shape,
|
||||
&rhs_shape,
|
||||
&dst_shape,
|
||||
);
|
||||
|
||||
let lhs = in_a.broadcast_to(generator, ctx, ndims_int, &lhs_shape);
|
||||
let rhs = in_b.broadcast_to(generator, ctx, ndims_int, &rhs_shape);
|
||||
|
||||
let dst = NDArrayType::new(ctx, llvm_dst_dtype, ndims_int)
|
||||
.construct_uninitialized(generator, ctx, None);
|
||||
dst.copy_shape_from_array(generator, ctx, dst_shape.base_ptr(ctx, generator));
|
||||
unsafe {
|
||||
dst.create_data(generator, ctx);
|
||||
}
|
||||
|
||||
(lhs, rhs, dst)
|
||||
};
|
||||
|
||||
let len = unsafe {
|
||||
lhs.shape().get_typed_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(ndims_int - 1, false),
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
let at_row = i64::try_from(ndims_int - 2).unwrap();
|
||||
let at_col = i64::try_from(ndims_int - 1).unwrap();
|
||||
|
||||
let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype);
|
||||
let dst_zero = dst_dtype_llvm.const_zero();
|
||||
|
||||
dst.foreach(generator, ctx, |generator, ctx, _, hdl| {
|
||||
let pdst_ij = hdl.get_pointer(ctx);
|
||||
|
||||
ctx.builder.build_store(pdst_ij, dst_zero).unwrap();
|
||||
|
||||
let indices = hdl.get_indices::<G>();
|
||||
let i = unsafe {
|
||||
indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_row as u64, true), None)
|
||||
};
|
||||
let j = unsafe {
|
||||
indices.get_unchecked(ctx, generator, &llvm_usize.const_int(at_col as u64, true), None)
|
||||
};
|
||||
|
||||
let num_0 = llvm_usize.const_int(0, false);
|
||||
let num_1 = llvm_usize.const_int(1, false);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
num_0,
|
||||
(len, false),
|
||||
|generator, ctx, _, k| {
|
||||
// `indices` is modified to index into `a` and `b`, and restored.
|
||||
unsafe {
|
||||
indices.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(at_row as u64, true),
|
||||
i,
|
||||
);
|
||||
indices.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(at_col as u64, true),
|
||||
k.into(),
|
||||
);
|
||||
}
|
||||
let a_ik = unsafe { lhs.data().get_unchecked(ctx, generator, &indices, None) };
|
||||
|
||||
unsafe {
|
||||
indices.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(at_row as u64, true),
|
||||
k.into(),
|
||||
);
|
||||
indices.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(at_col as u64, true),
|
||||
j,
|
||||
);
|
||||
}
|
||||
let b_kj = unsafe { rhs.data().get_unchecked(ctx, generator, &indices, None) };
|
||||
|
||||
// Restore `indices`.
|
||||
unsafe {
|
||||
indices.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(at_row as u64, true),
|
||||
i,
|
||||
);
|
||||
indices.set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_int(at_col as u64, true),
|
||||
j,
|
||||
);
|
||||
}
|
||||
|
||||
// x = a_[...]ik * b_[...]kj
|
||||
let x = gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(lhs_dtype), a_ik),
|
||||
Binop::normal(Operator::Mult),
|
||||
(&Some(rhs_dtype), b_kj),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
||||
|
||||
// dst_[...]ij += x
|
||||
let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap();
|
||||
let dst_ij = gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(dst_dtype), dst_ij),
|
||||
Binop::normal(Operator::Add),
|
||||
(&Some(dst_dtype), x),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, dst_dtype)?;
|
||||
ctx.builder.build_store(pdst_ij, dst_ij).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
num_1,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
dst
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayValue<'ctx> {
|
||||
/// Perform [`np.matmul`](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
|
||||
///
|
||||
/// This function always return an [`NDArrayValue`]. You may want to use
|
||||
/// [`NDArrayValue::split_unsized`] to handle when the output could be a scalar.
|
||||
///
|
||||
/// `dst_dtype` defines the dtype of the returned ndarray.
|
||||
#[must_use]
|
||||
pub fn matmul<G: CodeGenerator>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
self_ty: Type,
|
||||
(other_ty, other): (Type, Self),
|
||||
(out_dtype, out): (Type, NDArrayOut<'ctx>),
|
||||
) -> Self {
|
||||
// Sanity check, but type inference should prevent this.
|
||||
assert!(self.ndims > 0 && other.ndims > 0, "np.matmul disallows scalar input");
|
||||
|
||||
// If both arguments are 2-D they are multiplied like conventional matrices.
|
||||
//
|
||||
// If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the
|
||||
// last two indices and broadcast accordingly.
|
||||
//
|
||||
// If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its
|
||||
// dimensions. After matrix multiplication the prepended 1 is removed.
|
||||
//
|
||||
// If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its
|
||||
// dimensions. After matrix multiplication the appended 1 is removed.
|
||||
|
||||
let new_a = if self.ndims == 1 {
|
||||
// Prepend 1 to its dimensions
|
||||
self.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis])
|
||||
} else {
|
||||
*self
|
||||
};
|
||||
|
||||
let new_b = if other.ndims == 1 {
|
||||
// Append 1 to its dimensions
|
||||
other.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis])
|
||||
} else {
|
||||
other
|
||||
};
|
||||
|
||||
// NOTE: `result` will always be a newly allocated ndarray.
|
||||
// Current implementation cannot do in-place matrix muliplication.
|
||||
let mut result =
|
||||
matmul_at_least_2d(generator, ctx, out_dtype, (self_ty, new_a), (other_ty, new_b));
|
||||
|
||||
// Postprocessing on the result to remove prepended/appended axes.
|
||||
let mut postindices = vec![];
|
||||
let zero = ctx.ctx.i32_type().const_zero();
|
||||
|
||||
if self.ndims == 1 {
|
||||
// Remove the prepended 1
|
||||
postindices.push(RustNDIndex::SingleElement(zero));
|
||||
}
|
||||
|
||||
if other.ndims == 1 {
|
||||
// Remove the appended 1
|
||||
postindices.push(RustNDIndex::Ellipsis);
|
||||
postindices.push(RustNDIndex::SingleElement(zero));
|
||||
}
|
||||
|
||||
if !postindices.is_empty() {
|
||||
result = result.index(generator, ctx, &postindices);
|
||||
}
|
||||
|
||||
match out {
|
||||
NDArrayOut::NewNDArray { .. } => result,
|
||||
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
|
||||
let result_shape = result.shape();
|
||||
out_ndarray.assert_can_be_written_by_out(generator, ctx, result_shape);
|
||||
|
||||
out_ndarray.copy_data_from(ctx, result);
|
||||
out_ndarray
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
1046
nac3core/src/codegen/values/ndarray/mod.rs
Normal file
1046
nac3core/src/codegen/values/ndarray/mod.rs
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user