forked from M-Labs/nac3
1
0
Fork 0

Compare commits

..

129 Commits

Author SHA1 Message Date
occheung 015714eee1 copy constructor -> clone 2024-11-28 18:52:53 +08:00
occheung 71dec251e3 ld/dwarf: remove reader resets
DWARF reader never had to reverse. Readers are already copied to achieve this effect.
Plus the position that it reverses to might be questionable.
2024-11-28 18:52:53 +08:00
occheung fce61f7b8c ld: fix dwarf sections offset calculations 2024-11-28 18:52:53 +08:00
abdul124 babc081dbd core/toplevel: update tests 2024-11-27 14:31:57 +08:00
abdul124 5337dbe23b core/toplevel: add python-like error messages for class definition 2024-11-27 14:31:57 +08:00
abdul124 f862c01412 core/toplevel: refactor composer 2024-11-27 14:31:53 +08:00
David Mak 0c9705f5f1 [meta] Apply clippy changes 2024-11-25 16:05:12 +08:00
David Mak 5f940f86d9 [artiq] Fix obtaining ndarray struct from NDArrayType 2024-11-25 15:01:39 +08:00
Sebastien Bourdeauducq 5651e00688 flake: add platformdirs artiq dependency 2024-11-22 20:30:30 +08:00
Sebastien Bourdeauducq f6745b987f bump sipyco and artiq used for profiling 2024-11-22 19:43:03 +08:00
mwojcik e0dedc6580 nac3artiq: support kernels sent by content 2024-11-22 19:38:52 +08:00
David Mak 28f574282c [core_derive] Ignore doctest in example
Causes linker errors for unknown reasons.
2024-11-22 00:00:05 +08:00
David Mak 144f0922db [core] coregen/types: Implement StructFields for NDArray
Also rename some fields to better align with their naming in numpy.
2024-11-21 14:27:00 +08:00
David Mak c58ce9c3a9 [core] codegen/types: Implement NDArray in terms of i8*
Better aligns with the future implementation of ndstrides.
2024-11-21 14:27:00 +08:00
David Mak f7e296da53 [core] irrt: Break IRRT into several impl files
Each IRRT file is now mapped to one Rust file.
2024-11-21 14:27:00 +08:00
David Mak b58c99369e [core] irrt: Update some IRRT implementation
- Change CSlice to use `void*` for better pointer compatibility
- Only include impl *.hpp files in irrt.cpp
- Refactor typedef to using declaration
- Add missing ``// namespace`
2024-11-21 14:26:58 +08:00
David Mak 1a535db558 [core] codegen: Add dtype to NDArrayType
We won't have this once NDArray is refactored to strided impl.
2024-11-20 15:35:57 +08:00
David Mak 1ba2e287a6 [core] codegen: Add Self::llvm_type to all type abstractions 2024-11-20 15:35:57 +08:00
lyken f95f979ad3 core/irrt: fix exception.hpp C++ castings 2024-11-20 15:35:57 +08:00
lyken 48e2148c0f core/toplevel/helper: add {extract,create}_ndims 2024-11-20 15:35:57 +08:00
David Mak 88e57f7120 [core_derive] Initial implementation 2024-11-20 15:35:55 +08:00
David Mak d7633c42bc [core] codegen/types: Implement StructField{,s}
Loosely based on FieldTraversal by lyken.
2024-11-19 13:46:25 +08:00
David Mak a4f53b6e6b [core] codegen: Refactor ProxyType and ProxyValue
Accepts generator+context object for generic type checking. Also
implements more default trait impl for easier delegation.
2024-11-19 13:46:25 +08:00
David Mak 9d9ead211e [core] Move Proxies to their own modules 2024-11-19 13:46:23 +08:00
David Mak 26a1b85206 [core] codegen/classes: Remove Underlying type
This is confusing and we want a better abstraction than this.
2024-11-19 13:45:55 +08:00
David Mak 2822074b2d [meta] Cleanup from upgrading Rust version
- Remove rust_2024_edition warnings, since it wouldn't be released for
another 3 months
- Fix new clippy warnings
2024-11-19 13:43:57 +08:00
David Mak fe67ed076c [meta] Update pre-commit configuration 2024-11-19 13:20:27 +08:00
David Mak 94e2414df0 [meta] Update cargo dependencies 2024-11-19 13:20:26 +08:00
Sebastien Bourdeauducq 2cee760404 turn rust_2024_compatibility lints into warnings 2024-11-16 13:41:49 +08:00
Sebastien Bourdeauducq 230982dc84 update dependencies 2024-11-16 12:40:11 +08:00
occheung 2bd3f63991 boolop: terminate both branches with *_end_bb 2024-11-16 12:06:20 +08:00
occheung b53266e9e6 artiq: use async RPC for attributes writeback 2024-11-12 12:04:01 +08:00
occheung 86eb22bbf3 artiq: main is always the last module 2024-11-12 12:03:38 +08:00
occheung beaa38047d artiq: suppress main module debug warning 2024-11-12 12:03:08 +08:00
occheung 705dc4ff1c artiq: lump return value into attributes writeback RPC 2024-11-12 12:02:35 +08:00
occheung 979209a526 binop: expand `not` operator as loglcal not 2024-11-08 17:12:01 +08:00
David Mak c3927d0ef6 [ast] Refactor lazy_static to LazyLock
It is available in Rust 1.80 and reduces a dependency.
2024-10-30 12:29:51 +08:00
David Mak 202a902cd0 [meta] Update dependencies 2024-10-30 12:29:51 +08:00
David Mak b6e2644391 [meta] Update cargo dependencies 2024-10-18 14:17:16 +08:00
David Mak 45cd01556b [meta] Apply cargo fmt 2024-10-18 14:16:42 +08:00
David Mak b6cd2a6993 [meta] Reorganize order of use declarations - Phase 3 2024-10-17 16:25:52 +08:00
David Mak a98f33e6d1 [meta] Reorganize order of use declarations - Phase 2
Some more rules:

- For module-level imports, prefer no prefix > super > crate.
- Use crate instead of super if super refers to the crate-level module
2024-10-17 15:57:33 +08:00
David Mak 5839badadd [standalone] Update globals.py with type-inferred global var 2024-10-07 20:44:08 +08:00
David Mak 56c845aac4 [standalone] Add support for registering globals without type decl 2024-10-07 20:44:06 +08:00
David Mak 65a12d9ab3 [core] Refactor registration of top-level variables 2024-10-07 17:05:48 +08:00
David Mak 9c6685fa8f [core] typecheck/function_check: Fix lookup of defined ids in scope 2024-10-07 16:51:37 +08:00
David Mak 2bb788e4bb [core] codegen/expr: Materialize implicit globals
Required for when globals are read without the use of a global
declaration.
2024-10-07 13:13:20 +08:00
David Mak 42a2f243b5 [core] typecheck: Disallow redeclaration of var shadowing global 2024-10-07 13:11:00 +08:00
David Mak 3ce2eddcdc [core] typecheck/type_inferencer: Infer whether variables are global 2024-10-07 13:10:46 +08:00
David Mak 51bf126a32 [core] typecheck/type_inferencer: Differentiate global symbols
Required for analyzing use of global symbols before global declaration.
2024-10-07 12:25:00 +08:00
David Mak 1a197c67f6 [core] toplevel/composer: Reduce lock scope while analyzing function 2024-10-05 15:53:20 +08:00
David Mak 581b2f7bb2 [standalone] Add demo for global variables 2024-10-04 13:24:30 +08:00
David Mak 746329ec5d [standalone] Implement symbol resolution for globals 2024-10-04 13:24:30 +08:00
David Mak e60e8e837f [core] Add support for global statements 2024-10-04 13:24:27 +08:00
David Mak 9fdbe9695d [core] Add generator to SymbolResolver::get_symbol_value
Needed in a future commit.
2024-10-04 13:20:29 +08:00
David Mak 8065e73598 [core] toplevel/composer: Add type analysis for global variables 2024-10-04 13:20:29 +08:00
David Mak 192290889b [core] Add IdentifierInfo
Keeps track of whether an identifier refers to a global or local
variable.
2024-10-04 13:20:24 +08:00
David Mak 1407553a2f [core] Implement parsing of global variables
Globals are now parsed into symbol resolver and top level definitions.
2024-10-04 13:18:29 +08:00
David Mak c7697606e1 [core] Add TopLevelDef::Variable 2024-10-04 13:09:25 +08:00
David Mak 88d0ccbf69 [standalone] Explicit panic when encountering a compilation error
Otherwise scripts will continue to execute.
2024-10-04 13:00:16 +08:00
David Mak a43b59539c [meta] Move variables declarations closer to where they are first used 2024-10-04 13:00:16 +08:00
David Mak fe06b2806f [meta] Reorganize order of use declarations
Use declarations are now grouped into 4 groups:

- Declarations from the standard library
- Declarations from external crates
- Declarations from other crates in this project
- Declarations from within this module

Furthermore, all use declarations are grouped together to enhance
readability. super::super is also replaced by an equivalent crate::
declaration.
2024-10-04 12:52:01 +08:00
David Mak 7f6c9a25ac [meta] Update Cargo dependencies 2024-10-04 12:52:01 +08:00
Sébastien Bourdeauducq 6c8382219f msys2: get python via numpy dependencies 2024-09-30 14:27:30 +08:00
Sebastien Bourdeauducq 9274a7b96b flake: update nixpkgs 2024-09-30 14:22:40 +08:00
Sébastien Bourdeauducq d1c0fe2900 cargo: update dependencies 2024-09-30 14:14:43 +08:00
mwojcik f2c047ba57 artiq: support async rpcs
Co-authored-by: mwojcik <mw@m-labs.hk>
Co-committed-by: mwojcik <mw@m-labs.hk>
2024-09-13 12:12:13 +08:00
David Mak 5e2e77a500 [meta] Bump inkwell to v0.5 2024-09-13 11:11:14 +08:00
David Mak f3cc4702b9 [meta] Update dependencies 2024-09-13 11:11:14 +08:00
David Mak 3e92c491f5 [standalone] Add tests creating ndarrays with tuple dims 2024-09-11 15:52:43 +08:00
lyken 7f629f1579 core: fix comment in unify_call 2024-09-11 15:46:19 +08:00
lyken 5640a793e2 core: allow np_full to take tuple shapes 2024-09-11 15:46:19 +08:00
David Mak abbaa506ad [standalone] Remove redundant recreation of TargetMachine 2024-09-09 14:27:10 +08:00
David Mak f3dc02d646 [meta] Apply cargo fmt 2024-09-09 14:24:52 +08:00
David Mak ea217eaea1 [meta] Update pre-commit config
Directly invoke cargo using nix develop to avoid using the system cargo.
2024-09-09 14:24:38 +08:00
Sébastien Bourdeauducq 5a34551905 allow the use of the LLVM shared library
Which in turns allows working around the incompatibility of the LLVM static library
with Rust link-args=-rdynamic, which produces binaries that either fail to link (OpenBSD)
or segfault on startup (Linux).

The year is 2024 and compiler toolchains are still a trash fire like this.
2024-09-09 11:17:31 +08:00
Sebastien Bourdeauducq 6098b1b853 fix previous commit 2024-09-06 11:32:08 +08:00
Sebastien Bourdeauducq 668ccb1c95 nac3core: expose inkwell and nac3parser 2024-09-06 11:06:26 +08:00
Sebastien Bourdeauducq a3c624d69d update all dependencies 2024-09-06 10:21:58 +08:00
Sébastien Bourdeauducq bd06155f34 irrt: compatibility with pre-C23 compilers 2024-09-05 18:54:55 +08:00
David Mak 9c33c4209c [core] Fix type of ndarray.element_type
Should be the element type of the NDArray itself, not the pointer to its
type.
2024-08-30 22:47:38 +08:00
Sebastien Bourdeauducq 122983f11c flake: update dependencies 2024-08-30 14:45:38 +08:00
David Mak 71c3a65a31 [core] codegen/stmt: Fix obtaining return type of sret functions 2024-08-29 19:15:30 +08:00
David Mak 8c540d1033 [core] codegen/stmt: Add more casts for boolean types 2024-08-29 16:36:32 +08:00
David Mak 0cc60a3d33 [core] codegen/expr: Fix missing cast to i1 2024-08-29 16:36:32 +08:00
David Mak a59c26aa99 [artiq] Fix RPC of ndarrays from host 2024-08-29 16:08:45 +08:00
David Mak 02d93b11d1 [meta] Update dependencies 2024-08-29 14:32:21 +08:00
lyken 59cad5bfe1
standalone: clang-format demo.c 2024-08-29 10:37:24 +08:00
lyken 4318f8de84
standalone: improve src/assignment.py 2024-08-29 10:33:58 +08:00
David Mak 15ac00708a [core] Use quoted include paths instead of angled brackets
This is preferred for user-defined headers.
2024-08-28 16:37:03 +08:00
lyken c8dfdcfdea
standalone & artiq: remove class_names from resolver 2024-08-27 23:43:40 +08:00
Sébastien Bourdeauducq 600a5c8679 Revert "standalone: reformat demo.c"
This reverts commit 308edb8237.
2024-08-27 23:06:49 +08:00
lyken 22c4d25802 core/typecheck: add missing typecheck in matmul 2024-08-27 22:59:39 +08:00
lyken 308edb8237 standalone: reformat demo.c 2024-08-27 22:55:22 +08:00
lyken 9848795dcc core/irrt: add exceptions and debug utils 2024-08-27 22:55:22 +08:00
lyken 58222feed4 core/irrt: split into headers 2024-08-27 22:55:22 +08:00
lyken 518f21d174 core/irrt: build.rs capture IR defined constants 2024-08-27 22:55:22 +08:00
lyken e8e49684bf core/irrt: build.rs capture IR defined types 2024-08-27 22:55:22 +08:00
lyken b2900b4883 core/irrt: use +std=c++20 to compile
To explicitly set the C++ variant and avoid inconsistencies.
2024-08-27 22:55:22 +08:00
lyken c6dade1394 core/irrt: reformat 2024-08-27 22:55:22 +08:00
lyken 7e3fcc0845 add .clang-format 2024-08-27 22:55:22 +08:00
lyken d3b4c60d7f core/irrt: comment build.rs & move irrt to nac3core/irrt 2024-08-27 22:55:22 +08:00
abdul124 5b2b6db7ed core: improve error messages 2024-08-26 18:37:55 +08:00
abdul124 15e62f467e standalone: add tests for polymorphism 2024-08-26 18:37:55 +08:00
abdul124 2c88924ff7 core: add support for simple polymorphism 2024-08-26 18:37:55 +08:00
abdul124 a744b139ba core: allow Call and AnnAssign in init block 2024-08-26 18:37:55 +08:00
David Mak 2b2b2dbf8f [core] Fix resolution of exception names in raise short form
Previous implementation fails as `resolver.get_identifier_def` in ARTIQ
would return the exception __init__ function rather than the class.

We fix this by limiting the exception class resolution to only include
raise statements, and to force the exception name to always be treated
as a class.

Fixes #501.
2024-08-26 18:35:02 +08:00
David Mak d9f96dab33 [core] Add codegen_unreachable 2024-08-23 13:10:55 +08:00
David Mak c5ae0e7c36 [standalone] Add tests for tuple equality 2024-08-21 16:25:32 +08:00
David Mak b8dab6cf7c [standalone] Add tests for string equality 2024-08-21 16:25:32 +08:00
David Mak 4d80ba38b7 [core] codegen/expr: Implement comparison of tuples 2024-08-21 16:25:32 +08:00
David Mak 33929bda24 [core] typecheck/typedef: Add support for tuple methods 2024-08-21 16:25:32 +08:00
David Mak a8e92212c0 [core] codegen/expr: Implement string equality 2024-08-21 16:25:32 +08:00
David Mak 908271014a [core] typecheck/magic_methods: Add equality methods to string 2024-08-21 16:25:32 +08:00
David Mak c407622f5c [core] codegen/expr: Add compilation error for unsupported cmpop 2024-08-21 15:46:13 +08:00
David Mak d7952d0629 [core] codegen/expr: Fix assertions not generated for -O0 2024-08-21 15:36:54 +08:00
David Mak ca1395aed6 [core] codegen: Remove redundant return 2024-08-21 15:36:54 +08:00
David Mak 7799aa4987 [meta] Do not specify rev in dependency version 2024-08-21 15:36:54 +08:00
David Mak 76016a26ad [meta] Apply clippy suggestions 2024-08-21 13:07:57 +08:00
lyken 8532bf5206
standalone: add missing test_ndarray_ceil() run 2024-08-21 11:39:00 +08:00
lyken 2cf64d8608
apply clippy comment changes 2024-08-21 11:21:10 +08:00
lyken 706759adb2
artiq: apply cargo fmt 2024-08-21 11:21:10 +08:00
lyken b90cf2300b
core/fix: add missing lifetime in gen_for* 2024-08-21 11:05:30 +08:00
Sebastien Bourdeauducq 0fc26df29e flake: update nixpkgs 2024-08-19 23:53:15 +08:00
David Mak 0b074c2cf2 [artiq] symbol_resolver: Set private linkage for constants 2024-08-19 14:41:43 +08:00
Sébastien Bourdeauducq a0f6961e0e cargo: update dependencies 2024-08-19 13:15:03 +08:00
David Mak b1c5c2e1d4 [artiq] Fix RPC of ndarrays to host 2024-08-15 15:41:24 +08:00
David Mak 69320a6cf1 [artiq] Fix LLVM representation of strings
Should be `%str` rather than `[N x i8]`.
2024-08-14 09:30:08 +08:00
David Mak 9e0601837a core: Add compile-time feature to disable escape analysis 2024-08-14 09:29:48 +08:00
158 changed files with 11419 additions and 14599 deletions

View File

@ -1,3 +1,32 @@
BasedOnStyle: Google BasedOnStyle: LLVM
Language: Cpp
Standard: Cpp11
AccessModifierOffset: -1
AlignEscapedNewlines: Left
AlwaysBreakAfterReturnType: None
AlwaysBreakTemplateDeclarations: Yes
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortFunctionsOnASingleLine: Inline
BinPackParameters: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: AfterColon
BreakInheritanceList: AfterColon
ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ContinuationIndentWidth: 4
DerivePointerAlignment: false
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4 IndentWidth: 4
ReflowComments: false MaxEmptyLinesToKeep: 1
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SortUsingDeclarations: true
SpaceAfterTemplateKeyword: false
SpacesBeforeTrailingComments: 2
TabWidth: 4
UseTab: Never

View File

@ -1,24 +1,24 @@
# See https://pre-commit.com for more information # See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks # See https://pre-commit.com/hooks.html for more hooks
default_stages: [commit] default_stages: [pre-commit]
repos: repos:
- repo: local - repo: local
hooks: hooks:
- id: nac3-cargo-fmt - id: nac3-cargo-fmt
name: nac3 cargo format name: nac3 cargo format
entry: cargo entry: nix
language: system language: system
types: [file, rust] types: [file, rust]
pass_filenames: false pass_filenames: false
description: Runs cargo fmt on the codebase. description: Runs cargo fmt on the codebase.
args: [fmt] args: [develop, -c, cargo, fmt, --all]
- id: nac3-cargo-clippy - id: nac3-cargo-clippy
name: nac3 cargo clippy name: nac3 cargo clippy
entry: cargo entry: nix
language: system language: system
types: [file, rust] types: [file, rust]
pass_filenames: false pass_filenames: false
description: Runs cargo clippy on the codebase. description: Runs cargo clippy on the codebase.
args: [clippy, --tests] args: [develop, -c, cargo, clippy, --tests]

520
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -4,6 +4,7 @@ members = [
"nac3ast", "nac3ast",
"nac3parser", "nac3parser",
"nac3core", "nac3core",
"nac3core/nac3core_derive",
"nac3standalone", "nac3standalone",
"nac3artiq", "nac3artiq",
"runkernel", "runkernel",

View File

@ -2,11 +2,11 @@
"nodes": { "nodes": {
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1721924956, "lastModified": 1731319897,
"narHash": "sha256-Sb1jlyRO+N8jBXEX9Pg9Z1Qb8Bw9QyOgLDNMEpmjZ2M=", "narHash": "sha256-PbABj4tnbWFMfBp6OcUK5iGy1QY+/Z96ZcLpooIbuEI=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "5ad6a14c6bf098e98800b091668718c336effc95", "rev": "dc460ec76cbff0e66e269457d7b728432263166c",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -14,7 +14,6 @@
'' ''
mkdir -p $out/bin mkdir -p $out/bin
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
ln -s ${pkgs.llvmPackages_14.clang}/bin/clang $out/bin/clang-irrt-test
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
''; '';
demo-linalg-stub = pkgs.rustPlatform.buildRustPackage { demo-linalg-stub = pkgs.rustPlatform.buildRustPackage {
@ -41,7 +40,6 @@
cargoLock = { cargoLock = {
lockFile = ./Cargo.lock; lockFile = ./Cargo.lock;
}; };
cargoTestFlags = [ "--features" "test" ];
passthru.cargoLock = cargoLock; passthru.cargoLock = cargoLock;
nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ]; nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
buildInputs = [ pkgs.python3 llvm-nac3 ]; buildInputs = [ pkgs.python3 llvm-nac3 ];
@ -109,18 +107,18 @@
(pkgs.fetchFromGitHub { (pkgs.fetchFromGitHub {
owner = "m-labs"; owner = "m-labs";
repo = "sipyco"; repo = "sipyco";
rev = "939f84f9b5eef7efbf7423c735d1834783b6140e"; rev = "094a6cd63ffa980ef63698920170e50dc9ba77fd";
sha256 = "sha256-15Nun4EY35j+6SPZkjzZtyH/ncxLS60KuGJjFh5kSTc="; sha256 = "sha256-PPnAyDedUQ7Og/Cby9x5OT9wMkNGTP8GS53V6N/dk4w=";
}) })
(pkgs.fetchFromGitHub { (pkgs.fetchFromGitHub {
owner = "m-labs"; owner = "m-labs";
repo = "artiq"; repo = "artiq";
rev = "923ca3377d42c815f979983134ec549dc39d3ca0"; rev = "28c9de3e251daa89a8c9fd79d5ab64a3ec03bac6";
sha256 = "sha256-oJoEeNEeNFSUyh6jXG8Tzp6qHVikeHS0CzfE+mODPgw="; sha256 = "sha256-vAvpbHc5B+1wtG8zqN7j9dQE1ON+i22v+uqA+tw6Gak=";
}) })
]; ];
buildInputs = [ buildInputs = [
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb nac3artiq-instrumented ])) (python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb ps.platformdirs nac3artiq-instrumented ]))
pkgs.llvmPackages_14.llvm.out pkgs.llvmPackages_14.llvm.out
]; ];
phases = [ "buildPhase" "installPhase" ]; phases = [ "buildPhase" "installPhase" ];

View File

@ -12,15 +12,10 @@ crate-type = ["cdylib"]
itertools = "0.13" itertools = "0.13"
pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] } pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
parking_lot = "0.12" parking_lot = "0.12"
tempfile = "3.10" tempfile = "3.13"
nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
nac3ld = { path = "../nac3ld" } nac3ld = { path = "../nac3ld" }
[dependencies.inkwell]
version = "0.4"
default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[features] [features]
init-llvm-profile = [] init-llvm-profile = []
no-escape-analysis = ["nac3core/no-escape-analysis"]

View File

@ -112,10 +112,15 @@ def extern(function):
register_function(function) register_function(function)
return function return function
def rpc(function):
"""Decorates a function declaration defined by the core device runtime.""" def rpc(arg=None, flags={}):
register_function(function) """Decorates a function or method to be executed on the host interpreter."""
return function if arg is None:
def inner_decorator(function):
return rpc(function, flags)
return inner_decorator
register_function(arg)
return arg
def kernel(function_or_method): def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device.""" """Decorates a function or method to be executed on the core device."""
@ -201,7 +206,7 @@ class Core:
embedding = EmbeddingMap() embedding = EmbeddingMap()
if allow_registration: if allow_registration:
compiler.analyze(registered_functions, registered_classes) compiler.analyze(registered_functions, registered_classes, set())
allow_registration = False allow_registration = False
if hasattr(method, "__self__"): if hasattr(method, "__self__"):

26
nac3artiq/demo/str_abi.py Normal file
View File

@ -0,0 +1,26 @@
from min_artiq import *
from numpy import ndarray, zeros as np_zeros
@nac3
class StrFail:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@kernel
def hello(self, arg: str):
pass
@kernel
def consume_ndarray(self, arg: ndarray[str, 1]):
pass
def run(self):
self.hello("world")
self.consume_ndarray(np_zeros([10], dtype=str))
if __name__ == "__main__":
StrFail().run()

View File

@ -1,35 +1,3 @@
use nac3core::{
codegen::{
classes::{ListValue, NDArrayValue, RangeValue, UntypedArrayLikeAccessor},
expr::{destructure_range, gen_call},
irrt::call_ndarray_calc_size,
llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
CodeGenContext, CodeGenerator,
},
symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall},
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
};
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
use inkwell::{
context::Context,
module::Linkage,
types::IntType,
values::{BasicValueEnum, StructValue},
AddressSpace, IntPredicate,
};
use pyo3::{
types::{PyDict, PyList},
PyObject, PyResult, Python,
};
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use itertools::Itertools;
use std::{ use std::{
collections::{hash_map::DefaultHasher, HashMap}, collections::{hash_map::DefaultHasher, HashMap},
hash::{Hash, Hasher}, hash::{Hash, Hasher},
@ -38,6 +6,40 @@ use std::{
sync::Arc, sync::Arc,
}; };
use itertools::Itertools;
use pyo3::{
types::{PyDict, PyList},
PyObject, PyResult, Python,
};
use nac3core::{
codegen::{
expr::{destructure_range, gen_call},
irrt::call_ndarray_calc_size,
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave},
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
types::{NDArrayType, ProxyType},
values::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue,
RangeValue, UntypedArrayLikeAccessor,
},
CodeGenContext, CodeGenerator,
},
inkwell::{
context::Context,
module::Linkage,
types::{BasicType, IntType},
values::{BasicValueEnum, IntValue, PointerValue, StructValue},
AddressSpace, IntPredicate, OptimizationLevel,
},
nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef},
symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall},
typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
};
use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
/// The parallelism mode within a block. /// The parallelism mode within a block.
#[derive(Copy, Clone, Eq, PartialEq)] #[derive(Copy, Clone, Eq, PartialEq)]
enum ParallelMode { enum ParallelMode {
@ -127,7 +129,7 @@ impl<'a> ArtiqCodeGenerator<'a> {
/// (possibly indirect) `parallel` block. /// (possibly indirect) `parallel` block.
/// ///
/// * `store_name` - The LLVM value name for the pointer to `end`. `.addr` will be appended to /// * `store_name` - The LLVM value name for the pointer to `end`. `.addr` will be appended to
/// the end of the provided value name. /// the end of the provided value name.
fn timeline_update_end_max( fn timeline_update_end_max(
&mut self, &mut self,
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
@ -422,7 +424,10 @@ fn gen_rpc_tag(
} else { } else {
unreachable!() unreachable!()
}; };
assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims)); assert!(
(0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims),
"Only NDArrays of sizes between 0 and 255 can be RPCed"
);
buffer.push(b'a'); buffer.push(b'a');
buffer.push((ndarray_ndims & 0xFF) as u8); buffer.push((ndarray_ndims & 0xFF) as u8);
@ -434,17 +439,395 @@ fn gen_rpc_tag(
Ok(()) Ok(())
} }
/// Formats an RPC argument to conform to the expected format required by `send_value`.
///
/// See `artiq/firmware/libproto_artiq/rpc_proto.rs` for the expected format.
fn format_rpc_arg<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
(arg, arg_ty, arg_idx): (BasicValueEnum<'ctx>, Type, usize),
) -> PointerValue<'ctx> {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let arg_slot = match &*ctx.unifier.get_ty_immutable(arg_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
// NAC3: NDArray = { usize, usize*, T* }
// libproto_artiq: NDArray = [data[..], dim_sz[..]]
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
let llvm_arg = llvm_arg_ty.map_value(arg.into_pointer_value(), None);
let llvm_usize_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "")
.unwrap();
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
llvm_usize,
"",
)
.unwrap();
let dims_buf_sz =
ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
let buffer_size =
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
call_memcpy_generic(
ctx,
buffer.base_ptr(ctx, generator),
llvm_arg.ptr_to_data(ctx),
llvm_pdata_sizeof,
llvm_i1.const_zero(),
);
let pbuffer_dims_begin =
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
call_memcpy_generic(
ctx,
pbuffer_dims_begin,
llvm_arg.shape().base_ptr(ctx, generator),
dims_buf_sz,
llvm_i1.const_zero(),
);
buffer.base_ptr(ctx, generator)
}
_ => {
let arg_slot = generator
.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{arg_idx}")))
.unwrap();
ctx.builder.build_store(arg_slot, arg).unwrap();
ctx.builder
.build_bit_cast(arg_slot, llvm_pi8, "rpc.arg")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
};
debug_assert_eq!(arg_slot.get_type(), llvm_pi8);
arg_slot
}
/// Formats an RPC return value to conform to the expected format required by NAC3.
fn format_rpc_ret<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
ret_ty: Type,
) -> Option<BasicValueEnum<'ctx>> {
// -- receive value:
// T result = {
// void *ret_ptr = alloca(sizeof(T));
// void *ptr = ret_ptr;
// loop: int size = rpc_recv(ptr);
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
// if(size) { ptr = alloca(size); goto loop; }
// else *(T*)ret_ptr
// }
let llvm_i8 = ctx.ctx.i8_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
});
if ctx.unifier.unioned(ret_ty, ctx.primitives.none) {
ctx.build_call_or_invoke(rpc_recv, &[llvm_pi8.const_null().into()], "rpc_recv");
return None;
}
let prehead_bb = ctx.builder.get_insert_block().unwrap();
let current_function = prehead_bb.get_parent().unwrap();
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty);
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
// Round `val` up to its modulo `power_of_two`
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
val: IntValue<'ctx>,
power_of_two: IntValue<'ctx>| {
debug_assert_eq!(
val.get_type().get_bit_width(),
power_of_two.get_type().get_bit_width()
);
let llvm_val_t = val.get_type();
let max_rem = ctx
.builder
.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "")
.unwrap();
ctx.builder
.build_and(
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
ctx.builder.build_not(max_rem, "").unwrap(),
"",
)
.unwrap()
};
// Setup types
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
// Allocate the resulting ndarray
// A condition after format_rpc_ret ensures this will not be popped this off.
let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result"));
// Setup ndims
let ndims =
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
} else {
unreachable!();
};
// Set `ndarray.ndims`
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
// Allocate `ndarray.shape` [size_t; ndims]
ndarray.create_shape(ctx, llvm_usize, ndarray.load_ndims(ctx));
/*
ndarray now:
- .ndims: initialized
- .shape: allocated but uninitialized .shape
- .data: uninitialized
*/
let llvm_usize_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
.unwrap();
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
llvm_usize,
"",
)
.unwrap();
let llvm_elem_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
.unwrap();
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
// (4 + 4 * ndims) bytes with 8-byte alignment
let sizeof_dims =
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
let unaligned_buffer_size =
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false));
let stackptr = call_stacksave(ctx, None);
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment
let buffer = ctx
.builder
.build_array_alloca(
llvm_i8_8,
ctx.builder
.build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "")
.unwrap(),
"rpc.buffer",
)
.unwrap();
let buffer = ctx
.builder
.build_bit_cast(buffer, llvm_pi8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
//
// The returned value is the number of bytes for `ndarray.data`.
let ndarray_nbytes = ctx
.build_call_or_invoke(
rpc_recv,
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims].
"rpc.size.next",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
// debug_assert(ndarray_nbytes > 0)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::UGT,
ndarray_nbytes,
ndarray_nbytes.get_type().const_zero(),
"",
)
.unwrap(),
"0:AssertionError",
"Unexpected RPC termination for ndarray - Expected data buffer next",
[None, None, None],
ctx.current_loc,
);
}
// Copy shape from the buffer to `ndarray.shape`.
let pbuffer_dims =
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
call_memcpy_generic(
ctx,
ndarray.shape().base_ptr(ctx, generator),
pbuffer_dims,
sizeof_dims,
llvm_i1.const_zero(),
);
// Restore stack from before allocation of buffer
call_stackrestore(ctx, stackptr);
// Allocate `ndarray.data`.
// `ndarray.shape` must be initialized beforehand in this implementation
// (for ndarray.create_data() to know how many elements to allocate)
let num_elements =
call_ndarray_calc_size(generator, ctx, &ndarray.shape(), (None, None));
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let sizeof_data =
ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::UGE,
sizeof_data,
ndarray_nbytes,
"",
).unwrap(),
"0:AssertionError",
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
[Some(sizeof_data), Some(ndarray_nbytes), None],
ctx.current_loc,
);
}
ndarray.create_data(ctx, llvm_elem_ty, num_elements);
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
let ndarray_data_i8 =
ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
// NOTE: Currently on `prehead_bb`
ctx.builder.build_unconditional_branch(head_bb).unwrap();
// Inserting into `head_bb`. Do `rpc_recv` for `data` recursively.
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.map(BasicValueEnum::into_int_value)
.unwrap();
let is_done = ctx
.builder
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb);
// Align the allocation to sizeof(T)
let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof);
let alloc_ptr = ctx
.builder
.build_array_alloca(
llvm_elem_ty,
ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(),
"rpc.alloc",
)
.unwrap();
let alloc_ptr =
ctx.builder.build_pointer_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb);
ndarray.as_base_value().into()
}
_ => {
let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap();
let slotgen = ctx.builder.build_bit_cast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
phi.add_incoming(&[(&slotgen, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap()
.into_int_value();
let is_done = ctx
.builder
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb);
let alloc_ptr =
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
let alloc_ptr =
ctx.builder.build_bit_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb);
ctx.builder.build_load(slot, "rpc.result").unwrap()
}
};
Some(result)
}
fn rpc_codegen_callback_fn<'ctx>( fn rpc_codegen_callback_fn<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
is_async: bool,
) -> Result<Option<BasicValueEnum<'ctx>>, String> { ) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let size_type = generator.get_size_type(ctx.ctx);
let int8 = ctx.ctx.i8_type(); let int8 = ctx.ctx.i8_type();
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let size_type = generator.get_size_type(ctx.ctx);
let ptr_type = int8.ptr_type(AddressSpace::default());
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let service_id = int32.const_int(fun.1 .0 as u64, false); let service_id = int32.const_int(fun.1 .0 as u64, false);
@ -517,22 +900,25 @@ fn rpc_codegen_callback_fn<'ctx>(
.0 .0
.args .args
.iter() .iter()
.map(|arg| mapping.remove(&arg.name).unwrap().to_basic_value_enum(ctx, generator, arg.ty)) .map(|arg| {
.collect::<Result<Vec<_>, _>>()?; mapping
.remove(&arg.name)
.unwrap()
.to_basic_value_enum(ctx, generator, arg.ty)
.map(|llvm_val| (llvm_val, arg.ty))
})
.collect::<Result<Vec<(_, _)>, _>>()?;
if let Some(obj) = obj { if let Some(obj) = obj {
if let ValueEnum::Static(obj) = obj.1 { if let ValueEnum::Static(obj_val) = obj.1 {
real_params.insert(0, obj.get_const_obj(ctx, generator)); real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
} else { } else {
// should be an error here... // should be an error here...
panic!("only host object is allowed"); panic!("only host object is allowed");
} }
} }
for (i, arg) in real_params.iter().enumerate() { for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
let arg_slot = let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i));
generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
ctx.builder.build_store(arg_slot, *arg).unwrap();
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap();
let arg_ptr = unsafe { let arg_ptr = unsafe {
ctx.builder.build_gep( ctx.builder.build_gep(
args_ptr, args_ptr,
@ -545,91 +931,72 @@ fn rpc_codegen_callback_fn<'ctx>(
} }
// call // call
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| { if is_async {
ctx.module.add_function( let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
"rpc_send", ctx.module.add_function(
ctx.ctx.void_type().fn_type( "rpc_send_async",
&[ ctx.ctx.void_type().fn_type(
int32.into(), &[
tag_ptr_type.ptr_type(AddressSpace::default()).into(), int32.into(),
ptr_type.ptr_type(AddressSpace::default()).into(), tag_ptr_type.ptr_type(AddressSpace::default()).into(),
], ptr_type.ptr_type(AddressSpace::default()).into(),
false, ],
), false,
None, ),
) None,
}); )
ctx.builder });
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") ctx.builder
.unwrap(); .build_call(
rpc_send_async,
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
"rpc.send",
)
.unwrap();
} else {
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
ctx.module.add_function(
"rpc_send",
ctx.ctx.void_type().fn_type(
&[
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
false,
),
None,
)
});
ctx.builder
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
.unwrap();
}
// reclaim stack space used by arguments // reclaim stack space used by arguments
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
// -- receive value: if is_async {
// T result = { // async RPCs do not return any values
// void *ret_ptr = alloca(sizeof(T)); Ok(None)
// void *ptr = ret_ptr; } else {
// loop: int size = rpc_recv(ptr); let result = format_rpc_ret(generator, ctx, fun.0.ret);
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
// if(size) { ptr = alloca(size); goto loop; }
// else *(T*)ret_ptr
// }
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None)
});
if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv"); // An RPC returning an NDArray would not touch here.
return Ok(None); call_stackrestore(ctx, stackptr);
}
Ok(result)
} }
let prehead_bb = ctx.builder.get_insert_block().unwrap();
let current_function = prehead_bb.get_parent().unwrap();
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret);
let need_load = !ret_ty.is_pointer_type();
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot").unwrap();
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr").unwrap();
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr").unwrap();
phi.add_incoming(&[(&slotgen, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap()
.into_int_value();
let is_done = ctx
.builder
.build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done")
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb);
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc").unwrap();
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb);
let result = ctx.builder.build_load(slot, "rpc.result").unwrap();
if need_load {
call_stackrestore(ctx, stackptr);
}
Ok(Some(result))
} }
pub fn attributes_writeback( pub fn attributes_writeback<'ctx>(
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
inner_resolver: &InnerResolver, inner_resolver: &InnerResolver,
host_attributes: &PyObject, host_attributes: &PyObject,
return_obj: Option<(Type, ValueEnum<'ctx>)>,
) -> Result<(), String> { ) -> Result<(), String> {
Python::with_gil(|py| -> PyResult<Result<(), String>> { Python::with_gil(|py| -> PyResult<Result<(), String>> {
let host_attributes: &PyList = host_attributes.downcast(py)?; let host_attributes: &PyList = host_attributes.downcast(py)?;
@ -639,6 +1006,11 @@ pub fn attributes_writeback(
let zero = int32.const_zero(); let zero = int32.const_zero();
let mut values = Vec::new(); let mut values = Vec::new();
let mut scratch_buffer = Vec::new(); let mut scratch_buffer = Vec::new();
if let Some((ty, obj)) = return_obj {
values.push((ty, obj.to_basic_value_enum(ctx, generator, ty).unwrap()));
}
for val in (*globals).values() { for val in (*globals).values() {
let val = val.as_ref(py); let val = val.as_ref(py);
let ty = inner_resolver.get_obj_type( let ty = inner_resolver.get_obj_type(
@ -717,7 +1089,7 @@ pub fn attributes_writeback(
let args: Vec<_> = let args: Vec<_> =
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect(); values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
if let Err(e) = if let Err(e) =
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator) rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator, true)
{ {
return Ok(Err(e)); return Ok(Err(e));
} }
@ -727,9 +1099,9 @@ pub fn attributes_writeback(
Ok(()) Ok(())
} }
pub fn rpc_codegen_callback() -> Arc<GenCall> { pub fn rpc_codegen_callback(is_async: bool) -> Arc<GenCall> {
Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| { Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
rpc_codegen_callback_fn(ctx, obj, fun, args, generator) rpc_codegen_callback_fn(ctx, obj, fun, args, generator, is_async)
}))) })))
} }
@ -798,7 +1170,8 @@ fn polymorphic_print<'ctx>(
ctx.module.add_function(fn_name, fn_t, None) ctx.module.add_function(fn_name, fn_t, None)
}); });
let fmt = ctx.gen_string(generator, &fmt).get_field(generator, ctx.ctx, |f| f.base).value; let fmt = ctx.gen_string(generator, fmt);
let fmt = unsafe { fmt.get_field_at_index_unchecked(0) }.into_pointer_value();
ctx.builder ctx.builder
.build_call( .build_call(
@ -878,24 +1251,20 @@ fn polymorphic_print<'ctx>(
fmt.push_str("%.*s"); fmt.push_str("%.*s");
let true_str = ctx.gen_string(generator, "True"); let true_str = ctx.gen_string(generator, "True");
let true_data =
let true_data = true_str.get_field(generator, ctx.ctx, |f| f.base); unsafe { true_str.get_field_at_index_unchecked(0) }.into_pointer_value();
let true_len = true_str.get_field(generator, ctx.ctx, |f| f.len); let true_len = unsafe { true_str.get_field_at_index_unchecked(1) }.into_int_value();
let false_str = ctx.gen_string(generator, "False"); let false_str = ctx.gen_string(generator, "False");
let false_data =
let false_data = false_str.get_field(generator, ctx.ctx, |f| f.base); unsafe { false_str.get_field_at_index_unchecked(0) }.into_pointer_value();
let false_len = false_str.get_field(generator, ctx.ctx, |f| f.len); let false_len =
unsafe { false_str.get_field_at_index_unchecked(1) }.into_int_value();
let bool_val = generator.bool_to_i1(ctx, value.into_int_value()); let bool_val = generator.bool_to_i1(ctx, value.into_int_value());
args.extend([ args.extend([
ctx.builder ctx.builder.build_select(bool_val, true_len, false_len, "").unwrap(),
.build_select(bool_val, true_len.value, false_len.value, "") ctx.builder.build_select(bool_val, true_data, false_data, "").unwrap(),
.unwrap(),
ctx.builder
.build_select(bool_val, true_data.value, false_data.value, "")
.unwrap(),
]); ]);
} }
@ -946,7 +1315,8 @@ fn polymorphic_print<'ctx>(
fmt.push('['); fmt.push('[');
flush(ctx, generator, &mut fmt, &mut args); flush(ctx, generator, &mut fmt, &mut args);
let val = ListValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None); let val =
ListValue::from_pointer_value(value.into_pointer_value(), llvm_usize, None);
let len = val.load_size(ctx, None); let len = val.load_size(ctx, None);
let last = let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
@ -998,12 +1368,18 @@ fn polymorphic_print<'ctx>(
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
fmt.push_str("array(["); fmt.push_str("array([");
flush(ctx, generator, &mut fmt, &mut args); flush(ctx, generator, &mut fmt, &mut args);
let val = NDArrayValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None); let val = NDArrayValue::from_pointer_value(
let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None)); value.into_pointer_value(),
llvm_elem_ty,
llvm_usize,
None,
);
let len = call_ndarray_calc_size(generator, ctx, &val.shape(), (None, None));
let last = let last =
ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap();
@ -1056,7 +1432,7 @@ fn polymorphic_print<'ctx>(
fmt.push_str("range("); fmt.push_str("range(");
flush(ctx, generator, &mut fmt, &mut args); flush(ctx, generator, &mut fmt, &mut args);
let val = RangeValue::from_ptr_val(value.into_pointer_value(), None); let val = RangeValue::from_pointer_value(value.into_pointer_value(), None);
let (start, stop, step) = destructure_range(ctx, val); let (start, stop, step) = destructure_range(ctx, val);

View File

@ -1,10 +1,4 @@
#![deny( #![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)] #![warn(clippy::pedantic)]
#![allow( #![allow(
unsafe_op_in_unsafe_fn, unsafe_op_in_unsafe_fn,
@ -16,65 +10,65 @@
clippy::wildcard_imports clippy::wildcard_imports
)] )]
use std::collections::{HashMap, HashSet}; use std::{
use std::fs; collections::{HashMap, HashSet},
use std::io::Write; fs,
use std::process::Command; io::Write,
use std::rc::Rc; process::Command,
use std::sync::Arc; rc::Rc,
sync::Arc,
use inkwell::{
context::Context,
memory_buffer::MemoryBuffer,
module::{Linkage, Module},
passes::PassBuilderOptions,
support::is_multithreaded,
targets::*,
OptimizationLevel,
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3core::codegen::irrt::setup_irrt_exceptions;
use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions};
use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{into_var_map, TypeEnum, Unifier, VarMap};
use nac3parser::{
ast::{ExprKind, Stmt, StmtKind, StrRef},
parser::parse_program,
};
use pyo3::create_exception;
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use pyo3::{
create_exception, exceptions,
prelude::*,
types::{PyBytes, PyDict, PyNone, PySet},
};
use tempfile::{self, TempDir};
use nac3core::{ use nac3core::{
codegen::irrt::load_irrt, codegen::{
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry}, concrete_type::ConcreteTypeStore, gen_func_impl, irrt::load_irrt, CodeGenLLVMOptions,
CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator, WithCall, WorkerRegistry,
},
inkwell::{
context::Context,
memory_buffer::MemoryBuffer,
module::{FlagBehavior, Linkage, Module},
passes::PassBuilderOptions,
support::is_multithreaded,
targets::*,
OptimizationLevel,
},
nac3parser::{
ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
parser::parse_program,
},
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{ toplevel::{
builtins::get_exn_constructor,
composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer}, composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer},
DefinitionId, GenCall, TopLevelDef, DefinitionId, GenCall, TopLevelDef,
}, },
typecheck::typedef::{FunSignature, FuncArg}, typecheck::{
typecheck::{type_inferencer::PrimitiveStore, typedef::Type}, type_inferencer::PrimitiveStore,
typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
}; };
use nac3ld::Linker; use nac3ld::Linker;
use crate::{ use codegen::{
codegen::{ attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator,
attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator,
},
symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver},
}; };
use tempfile::{self, TempDir}; use symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver};
use timeline::TimeFns;
mod codegen; mod codegen;
mod symbol_resolver; mod symbol_resolver;
mod timeline; mod timeline;
use timeline::TimeFns;
#[derive(PartialEq, Clone, Copy)] #[derive(PartialEq, Clone, Copy)]
enum Isa { enum Isa {
Host, Host,
@ -148,14 +142,32 @@ impl Nac3 {
module: &PyObject, module: &PyObject,
registered_class_ids: &HashSet<u64>, registered_class_ids: &HashSet<u64>,
) -> PyResult<()> { ) -> PyResult<()> {
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { let (module_name, source_file, source) =
let module: &PyAny = module.extract(py)?; Python::with_gil(|py| -> PyResult<(String, String, String)> {
Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?)) let module: &PyAny = module.extract(py)?;
})?; let source_file = module.getattr("__file__");
let (source_file, source) = if let Ok(source_file) = source_file {
let source_file = source_file.extract()?;
(
source_file,
fs::read_to_string(source_file).map_err(|e| {
exceptions::PyIOError::new_err(format!(
"failed to read input file: {e}"
))
})?,
)
} else {
// kernels submitted by content have no file
// but still can provide source by StringLoader
let get_src_fn = module
.getattr("__loader__")?
.extract::<PyObject>()?
.getattr(py, "get_source")?;
("<expcontent>", get_src_fn.call1(py, (PyNone::get(py),))?.extract(py)?)
};
Ok((module.getattr("__name__")?.extract()?, source_file.to_string(), source))
})?;
let source = fs::read_to_string(&source_file).map_err(|e| {
exceptions::PyIOError::new_err(format!("failed to read input file: {e}"))
})?;
let parser_result = parse_program(&source, source_file.into()) let parser_result = parse_program(&source, source_file.into())
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?; .map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?;
@ -195,10 +207,8 @@ impl Nac3 {
body.retain(|stmt| { body.retain(|stmt| {
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node { if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
decorator_list.iter().any(|decorator| { decorator_list.iter().any(|decorator| {
if let ExprKind::Name { id, .. } = decorator.node { if let Some(id) = decorator_id_string(decorator) {
id.to_string() == "kernel" id == "kernel" || id == "portable" || id == "rpc"
|| id.to_string() == "portable"
|| id.to_string() == "rpc"
} else { } else {
false false
} }
@ -211,9 +221,8 @@ impl Nac3 {
} }
StmtKind::FunctionDef { ref decorator_list, .. } => { StmtKind::FunctionDef { ref decorator_list, .. } => {
decorator_list.iter().any(|decorator| { decorator_list.iter().any(|decorator| {
if let ExprKind::Name { id, .. } = decorator.node { if let Some(id) = decorator_id_string(decorator) {
let id = id.to_string(); id == "extern" || id == "kernel" || id == "portable" || id == "rpc"
id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
} else { } else {
false false
} }
@ -449,7 +458,6 @@ impl Nac3 {
pyid_to_type: pyid_to_type.clone(), pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(), primitive_ids: self.primitive_ids.clone(),
global_value_ids: global_value_ids.clone(), global_value_ids: global_value_ids.clone(),
class_names: Mutex::default(),
name_to_pyid: name_to_pyid.clone(), name_to_pyid: name_to_pyid.clone(),
module: module.clone(), module: module.clone(),
id_to_pyval: RwLock::default(), id_to_pyval: RwLock::default(),
@ -480,9 +488,25 @@ impl Nac3 {
match &stmt.node { match &stmt.node {
StmtKind::FunctionDef { decorator_list, .. } => { StmtKind::FunctionDef { decorator_list, .. } => {
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) { if decorator_list
store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap(); .iter()
rpc_ids.push((None, def_id)); .any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string()))
{
store_fun
.call1(
py,
(
def_id.0.into_py(py),
module.getattr(py, name.to_string().as_str()).unwrap(),
),
)
.unwrap();
let is_async = decorator_list.iter().any(|decorator| {
decorator_get_flags(decorator)
.iter()
.any(|constant| *constant == Constant::Str("async".into()))
});
rpc_ids.push((None, def_id, is_async));
} }
} }
StmtKind::ClassDef { name, body, .. } => { StmtKind::ClassDef { name, body, .. } => {
@ -490,19 +514,26 @@ impl Nac3 {
let class_obj = module.getattr(py, class_name.as_str()).unwrap(); let class_obj = module.getattr(py, class_name.as_str()).unwrap();
for stmt in body { for stmt in body {
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node { if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) { if decorator_list.iter().any(|decorator| {
decorator_id_string(decorator) == Some("rpc".to_string())
}) {
let is_async = decorator_list.iter().any(|decorator| {
decorator_get_flags(decorator)
.iter()
.any(|constant| *constant == Constant::Str("async".into()))
});
if name == &"__init__".into() { if name == &"__init__".into() {
return Err(CompileError::new_err(format!( return Err(CompileError::new_err(format!(
"compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})", "compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})",
class_name, stmt.location class_name, stmt.location
))); )));
} }
rpc_ids.push((Some((class_obj.clone(), *name)), def_id)); rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async));
} }
} }
} }
} }
_ => () _ => (),
} }
let id = *name_to_pyid.get(&name).unwrap(); let id = *name_to_pyid.get(&name).unwrap();
@ -541,7 +572,6 @@ impl Nac3 {
pyid_to_type: pyid_to_type.clone(), pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(), primitive_ids: self.primitive_ids.clone(),
global_value_ids: global_value_ids.clone(), global_value_ids: global_value_ids.clone(),
class_names: Mutex::default(),
id_to_pyval: RwLock::default(), id_to_pyval: RwLock::default(),
id_to_primitive: RwLock::default(), id_to_primitive: RwLock::default(),
field_to_val: RwLock::default(), field_to_val: RwLock::default(),
@ -559,9 +589,8 @@ impl Nac3 {
.unwrap(); .unwrap();
// Process IRRT // Process IRRT
let context = inkwell::context::Context::create(); let context = Context::create();
let irrt = load_irrt(&context); let irrt = load_irrt(&context, resolver.as_ref());
setup_irrt_exceptions(&context, &irrt, resolver.as_ref());
let fun_signature = let fun_signature =
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() }; FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
@ -600,13 +629,12 @@ impl Nac3 {
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
{ {
let rpc_codegen = rpc_codegen_callback();
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
for (class_data, id) in &rpc_ids { for (class_data, id, is_async) in &rpc_ids {
let mut def = defs[id.0].write(); let mut def = defs[id.0].write();
match &mut *def { match &mut *def {
TopLevelDef::Function { codegen_callback, .. } => { TopLevelDef::Function { codegen_callback, .. } => {
*codegen_callback = Some(rpc_codegen.clone()); *codegen_callback = Some(rpc_codegen_callback(*is_async));
} }
TopLevelDef::Class { methods, .. } => { TopLevelDef::Class { methods, .. } => {
let (class_def, method_name) = class_data.as_ref().unwrap(); let (class_def, method_name) = class_data.as_ref().unwrap();
@ -617,7 +645,7 @@ impl Nac3 {
if let TopLevelDef::Function { codegen_callback, .. } = if let TopLevelDef::Function { codegen_callback, .. } =
&mut *defs[id.0].write() &mut *defs[id.0].write()
{ {
*codegen_callback = Some(rpc_codegen.clone()); *codegen_callback = Some(rpc_codegen_callback(*is_async));
store_fun store_fun
.call1( .call1(
py, py,
@ -632,6 +660,11 @@ impl Nac3 {
} }
} }
} }
TopLevelDef::Variable { .. } => {
return Err(CompileError::new_err(String::from(
"Unsupported @rpc annotation on global variable",
)))
}
} }
} }
} }
@ -652,33 +685,12 @@ impl Nac3 {
let task = CodeGenTask { let task = CodeGenTask {
subst: Vec::default(), subst: Vec::default(),
symbol_name: "__modinit__".to_string(), symbol_name: "__modinit__".to_string(),
body: instance.body,
signature,
resolver: resolver.clone(),
store,
unifier_index: instance.unifier_id,
calls: instance.calls,
id: 0,
};
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature = store.from_signature(
&mut composer.unifier,
&self.primitive,
&fun_signature,
&mut cache,
);
let signature = store.add_cty(signature);
let attributes_writeback_task = CodeGenTask {
subst: Vec::default(),
symbol_name: "attributes_writeback".to_string(),
body: Arc::new(Vec::default()), body: Arc::new(Vec::default()),
signature, signature,
resolver, resolver,
store, store,
unifier_index: instance.unifier_id, unifier_index: instance.unifier_id,
calls: Arc::new(HashMap::default()), calls: instance.calls,
id: 0, id: 0,
}; };
@ -691,7 +703,7 @@ impl Nac3 {
let buffer = buffer.as_slice().into(); let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}))); })));
let size_t = Context::create() let size_t = context
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None) .ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
.get_bit_width(); .get_bit_width();
let num_threads = if is_multithreaded() { 4 } else { 1 }; let num_threads = if is_multithreaded() { 4 } else { 1 };
@ -702,19 +714,27 @@ impl Nac3 {
.collect(); .collect();
let membuffer = membuffers.clone(); let membuffer = membuffers.clone();
let mut has_return = false;
py.allow_threads(|| { py.allow_threads(|| {
let (registry, handles) = let (registry, handles) =
WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f); WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
let mut generator = let mut generator = ArtiqCodeGenerator::new("main".to_string(), size_t, self.time_fns);
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns); let context = Context::create();
let context = inkwell::context::Context::create(); let module = context.create_module("main");
let module = context.create_module("attributes_writeback");
let target_machine = self.llvm_options.create_target_machine().unwrap(); let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout()); module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple()); module.set_triple(&target_machine.get_triple());
module.add_basic_value_flag(
"Debug Info Version",
FlagBehavior::Warning,
context.i32_type().const_int(3, false),
);
module.add_basic_value_flag(
"Dwarf Version",
FlagBehavior::Warning,
context.i32_type().const_int(4, false),
);
let builder = context.create_builder(); let builder = context.create_builder();
let (_, module, _) = gen_func_impl( let (_, module, _) = gen_func_impl(
&context, &context,
@ -722,9 +742,27 @@ impl Nac3 {
&registry, &registry,
builder, builder,
module, module,
attributes_writeback_task, task,
|generator, ctx| { |generator, ctx| {
attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes) assert_eq!(instance.body.len(), 1, "toplevel module should have 1 statement");
let StmtKind::Expr { value: ref expr, .. } = instance.body[0].node else {
unreachable!("toplevel statement must be an expression")
};
let ExprKind::Call { .. } = expr.node else {
unreachable!("toplevel expression must be a function call")
};
let return_obj =
generator.gen_expr(ctx, expr)?.map(|value| (expr.custom.unwrap(), value));
has_return = return_obj.is_some();
registry.wait_tasks_complete(handles);
attributes_writeback(
ctx,
generator,
inner_resolver.as_ref(),
&host_attributes,
return_obj,
)
}, },
) )
.unwrap(); .unwrap();
@ -733,35 +771,23 @@ impl Nac3 {
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}); });
embedding_map.setattr("expects_return", has_return).unwrap();
// Link all modules into `main`. // Link all modules into `main`.
let buffers = membuffers.lock(); let buffers = membuffers.lock();
let main = context let main = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main")) .create_module_from_ir(MemoryBuffer::create_from_memory_range(
buffers.last().unwrap(),
"main",
))
.unwrap(); .unwrap();
for buffer in buffers.iter().skip(1) { for buffer in buffers.iter().rev().skip(1) {
let other = context let other = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main")) .create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
.unwrap(); .unwrap();
main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?; main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
} }
let builder = context.create_builder();
let modinit_return = main
.get_function("__modinit__")
.unwrap()
.get_last_basic_block()
.unwrap()
.get_terminator()
.unwrap();
builder.position_before(&modinit_return);
builder
.build_call(
main.get_function("attributes_writeback").unwrap(),
&[],
"attributes_writeback",
)
.unwrap();
main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?; main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?;
let mut function_iter = main.get_first_function(); let mut function_iter = main.get_first_function();
@ -848,6 +874,41 @@ impl Nac3 {
} }
} }
/// Retrieves the Name.id from a decorator, supports decorators with arguments.
fn decorator_id_string(decorator: &Located<ExprKind>) -> Option<String> {
if let ExprKind::Name { id, .. } = decorator.node {
// Bare decorator
return Some(id.to_string());
} else if let ExprKind::Call { func, .. } = &decorator.node {
// Decorators that are calls (e.g. "@rpc()") have Call for the node,
// need to extract the id from within.
if let ExprKind::Name { id, .. } = func.node {
return Some(id.to_string());
}
}
None
}
/// Retrieves flags from a decorator, if any.
fn decorator_get_flags(decorator: &Located<ExprKind>) -> Vec<Constant> {
let mut flags = vec![];
if let ExprKind::Call { keywords, .. } = &decorator.node {
for keyword in keywords {
if keyword.node.arg != Some("flags".into()) {
continue;
}
if let ExprKind::Set { elts } = &keyword.node.value.node {
for elt in elts {
if let ExprKind::Constant { value, .. } = &elt.node {
flags.push(value.clone());
}
}
}
}
}
flags
}
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> { fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
let linker_args = vec![ let linker_args = vec![
"-shared".to_string(), "-shared".to_string(),
@ -1029,7 +1090,12 @@ impl Nac3 {
}) })
} }
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> { fn analyze(
&mut self,
functions: &PySet,
classes: &PySet,
content_modules: &PySet,
) -> PyResult<()> {
let (modules, class_ids) = let (modules, class_ids) =
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> { Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
let mut modules: HashMap<u64, PyObject> = HashMap::new(); let mut modules: HashMap<u64, PyObject> = HashMap::new();
@ -1039,14 +1105,22 @@ impl Nac3 {
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?; let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
for function in functions { for function in functions {
let module = getmodule_fn.call1((function,))?.extract()?; let module: PyObject = getmodule_fn.call1((function,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module); if !module.is_none(py) {
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
} }
for class in classes { for class in classes {
let module = getmodule_fn.call1((class,))?.extract()?; let module: PyObject = getmodule_fn.call1((class,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module); if !module.is_none(py) {
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
class_ids.insert(id_fn.call1((class,))?.extract()?); class_ids.insert(id_fn.call1((class,))?.extract()?);
} }
for module in content_modules {
let module: PyObject = module.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
Ok((modules, class_ids)) Ok((modules, class_ids))
})?; })?;

View File

@ -1,14 +1,30 @@
use inkwell::{ use std::{
types::{BasicType, BasicTypeEnum}, collections::{HashMap, HashSet},
values::BasicValueEnum, sync::{
AddressSpace, atomic::{AtomicBool, Ordering::Relaxed},
Arc,
},
}; };
use itertools::Itertools; use itertools::Itertools;
use parking_lot::RwLock;
use pyo3::{
types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python,
};
use nac3core::{ use nac3core::{
codegen::{ codegen::{
classes::{NDArrayType, ProxyType}, types::{NDArrayType, ProxyType},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
inkwell::{
module::Linkage,
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
AddressSpace,
},
nac3parser::ast::{self, StrRef},
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{ toplevel::{
helper::PrimDef, helper::PrimDef,
@ -20,21 +36,8 @@ use nac3core::{
typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap}, typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap},
}, },
}; };
use nac3parser::ast::{self, StrRef};
use parking_lot::{Mutex, RwLock};
use pyo3::{
types::{PyDict, PyTuple},
PyAny, PyObject, PyResult, Python,
};
use std::{
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
Arc,
},
};
use crate::PrimitivePythonId; use super::PrimitivePythonId;
pub enum PrimitiveValue { pub enum PrimitiveValue {
I32(i32), I32(i32),
@ -79,7 +82,6 @@ pub struct InnerResolver {
pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>, pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>,
pub field_to_val: RwLock<HashMap<ResolverField, Option<PyFieldHandle>>>, pub field_to_val: RwLock<HashMap<ResolverField, Option<PyFieldHandle>>>,
pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>, pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>,
pub class_names: Mutex<HashMap<StrRef, Type>>,
pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>, pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>, pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
pub primitive_ids: PrimitivePythonId, pub primitive_ids: PrimitivePythonId,
@ -133,6 +135,8 @@ impl StaticValue for PythonValue {
format!("{}_const", self.id).as_str(), format!("{}_const", self.id).as_str(),
); );
global.set_constant(true); global.set_constant(true);
// Set linkage of global to private to avoid name collisions
global.set_linkage(Linkage::Private);
global.set_initializer(&ctx.ctx.const_struct( global.set_initializer(&ctx.ctx.const_struct(
&[ctx.ctx.i32_type().const_int(u64::from(id), false).into()], &[ctx.ctx.i32_type().const_int(u64::from(id), false).into()],
false, false,
@ -163,7 +167,7 @@ impl StaticValue for PythonValue {
PrimitiveValue::Bool(val) => { PrimitiveValue::Bool(val) => {
ctx.ctx.i8_type().const_int(u64::from(*val), false).into() ctx.ctx.i8_type().const_int(u64::from(*val), false).into()
} }
PrimitiveValue::Str(val) => ctx.ctx.const_string(val.as_bytes(), true).into(), PrimitiveValue::Str(val) => ctx.gen_string(generator, val).into(),
}); });
} }
if let Some(global) = ctx.module.get_global(&self.id.to_string()) { if let Some(global) = ctx.module.get_global(&self.id.to_string()) {
@ -977,7 +981,7 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ {
let val: String = obj.extract().unwrap(); let val: String = obj.extract().unwrap();
self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone())); self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone()));
Ok(Some(ctx.ctx.const_string(val.as_bytes(), true).into())) Ok(Some(ctx.gen_string(generator, val).into()))
} else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 {
let val: f64 = obj.extract().unwrap(); let val: f64 = obj.extract().unwrap();
self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val)); self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val));
@ -1092,7 +1096,7 @@ impl InnerResolver {
if self.global_value_ids.read().contains_key(&id) { if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global( ctx.module.add_global(
ndarray_llvm_ty.as_underlying_type(), ndarray_llvm_ty.as_base_type().get_element_type().into_struct_type(),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&id_str, &id_str,
) )
@ -1186,20 +1190,24 @@ impl InnerResolver {
data_global.set_initializer(&data); data_global.set_initializer(&data);
// create a global for the ndarray object and initialize it // create a global for the ndarray object and initialize it
let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[ let value = ndarray_llvm_ty
llvm_usize.const_int(ndarray_ndims, false).into(), .as_base_type()
shape_global .get_element_type()
.as_pointer_value() .into_struct_type()
.const_cast(llvm_usize.ptr_type(AddressSpace::default())) .const_named_struct(&[
.into(), llvm_usize.const_int(ndarray_ndims, false).into(),
data_global shape_global
.as_pointer_value() .as_pointer_value()
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default())) .const_cast(llvm_usize.ptr_type(AddressSpace::default()))
.into(), .into(),
]); data_global
.as_pointer_value()
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
.into(),
]);
let ndarray = ctx.module.add_global( let ndarray = ctx.module.add_global(
ndarray_llvm_ty.as_underlying_type(), ndarray_llvm_ty.as_base_type().get_element_type().into_struct_type(),
Some(AddressSpace::default()), Some(AddressSpace::default()),
&id_str, &id_str,
); );
@ -1466,6 +1474,7 @@ impl SymbolResolver for Resolver {
&self, &self,
id: StrRef, id: StrRef,
_: &mut CodeGenContext<'ctx, '_>, _: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
let sym_value = { let sym_value = {
let id_to_val = self.0.id_to_pyval.read(); let id_to_val = self.0.id_to_pyval.read();

View File

@ -1,9 +1,12 @@
use inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
};
use itertools::Either; use itertools::Either;
use nac3core::codegen::CodeGenContext;
use nac3core::{
codegen::CodeGenContext,
inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
},
};
/// Functions for manipulating the timeline. /// Functions for manipulating the timeline.
pub trait TimeFns { pub trait TimeFns {
@ -31,7 +34,7 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx let now_hiptr = ctx
.builder .builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
@ -80,7 +83,7 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx let now_hiptr = ctx
.builder .builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
@ -109,7 +112,7 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx let now_hiptr = ctx
.builder .builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
@ -207,7 +210,7 @@ impl TimeFns for NowPinningTimeFns {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx let now_hiptr = ctx
.builder .builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
@ -258,7 +261,7 @@ impl TimeFns for NowPinningTimeFns {
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap(); let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now_hiptr = ctx let now_hiptr = ctx
.builder .builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();

View File

@ -10,7 +10,6 @@ constant-optimization = ["fold"]
fold = [] fold = []
[dependencies] [dependencies]
lazy_static = "1.5"
parking_lot = "0.12" parking_lot = "0.12"
string-interner = "0.17" string-interner = "0.17"
fxhash = "0.2" fxhash = "0.2"

View File

@ -5,14 +5,12 @@ pub use crate::location::Location;
use fxhash::FxBuildHasher; use fxhash::FxBuildHasher;
use parking_lot::{Mutex, MutexGuard}; use parking_lot::{Mutex, MutexGuard};
use std::{cell::RefCell, collections::HashMap, fmt}; use std::{cell::RefCell, collections::HashMap, fmt, sync::LazyLock};
use string_interner::{symbol::SymbolU32, DefaultBackend, StringInterner}; use string_interner::{symbol::SymbolU32, DefaultBackend, StringInterner};
pub type Interner = StringInterner<DefaultBackend, FxBuildHasher>; pub type Interner = StringInterner<DefaultBackend, FxBuildHasher>;
lazy_static! { static INTERNER: LazyLock<Mutex<Interner>> =
static ref INTERNER: Mutex<Interner> = LazyLock::new(|| Mutex::new(StringInterner::with_hasher(FxBuildHasher::default())));
Mutex::new(StringInterner::with_hasher(FxBuildHasher::default()));
}
thread_local! { thread_local! {
static LOCAL_INTERNER: RefCell<HashMap<String, StrRef>> = RefCell::default(); static LOCAL_INTERNER: RefCell<HashMap<String, StrRef>> = RefCell::default();

View File

@ -1,10 +1,4 @@
#![deny( #![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)] #![warn(clippy::pedantic)]
#![allow( #![allow(
clippy::missing_errors_doc, clippy::missing_errors_doc,
@ -14,9 +8,6 @@
clippy::wildcard_imports clippy::wildcard_imports
)] )]
#[macro_use]
extern crate lazy_static;
mod ast_gen; mod ast_gen;
mod constant; mod constant;
#[cfg(feature = "fold")] #[cfg(feature = "fold")]

View File

@ -1,26 +1,29 @@
[features]
test = []
[package] [package]
name = "nac3core" name = "nac3core"
version = "0.1.0" version = "0.1.0"
authors = ["M-Labs"] authors = ["M-Labs"]
edition = "2021" edition = "2021"
[features]
default = ["derive"]
derive = ["dep:nac3core_derive"]
no-escape-analysis = []
[dependencies] [dependencies]
itertools = "0.13" itertools = "0.13"
crossbeam = "0.8" crossbeam = "0.8"
indexmap = "2.2" indexmap = "2.6"
parking_lot = "0.12" parking_lot = "0.12"
rayon = "1.8" rayon = "1.10"
nac3core_derive = { path = "nac3core_derive", optional = true }
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
strum = "0.26.2" strum = "0.26"
strum_macros = "0.26.4" strum_macros = "0.26"
[dependencies.inkwell] [dependencies.inkwell]
version = "0.4" version = "0.5"
default-features = false default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"] features = ["llvm14-0-prefer-dynamic", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[dev-dependencies] [dev-dependencies]
test-case = "1.2.0" test-case = "1.2.0"

View File

@ -1,67 +1,63 @@
use regex::Regex;
use std::{ use std::{
env, env,
fs::File, fs::File,
io::Write, io::Write,
path::{Path, PathBuf}, path::Path,
process::{Command, Stdio}, process::{Command, Stdio},
}; };
const CMD_IRRT_CLANG: &str = "clang-irrt"; use regex::Regex;
const CMD_IRRT_CLANG_TEST: &str = "clang-irrt-test";
const CMD_IRRT_LLVM_AS: &str = "llvm-as-irrt";
fn get_out_dir() -> PathBuf { fn main() {
PathBuf::from(env::var("OUT_DIR").unwrap()) let out_dir = env::var("OUT_DIR").unwrap();
} let out_dir = Path::new(&out_dir);
let irrt_dir = Path::new("irrt");
fn get_irrt_dir() -> &'static Path { let irrt_cpp_path = irrt_dir.join("irrt.cpp");
Path::new("irrt")
}
/// Compile `irrt.cpp` for use in `src/codegen`
fn compile_irrt_cpp() {
let out_dir = get_out_dir();
let irrt_dir = get_irrt_dir();
/* /*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode. * HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get. * Compiling for WASM32 and filtering the output with regex is the closest we can get.
*/ */
let irrt_cpp_path = irrt_dir.join("irrt.cpp"); let mut flags: Vec<&str> = vec![
"--target=wasm32",
let mut flags = vec![]; "-x",
flags.push("--target=wasm32"); "c++",
flags.extend(&["-x", "c++"]); "-std=c++20",
flags.extend(&["-fno-discard-value-names", "-fno-exceptions", "-fno-rtti"]); "-fno-discard-value-names",
flags.push("-emit-llvm"); "-fno-exceptions",
flags.push("-S"); "-fno-rtti",
flags.extend(&["-Wall", "-Wextra"]); "-emit-llvm",
flags.extend(&["-o", "-"]); "-S",
flags.extend(&["-I", irrt_dir.to_str().unwrap()]); "-Wall",
flags.push(irrt_cpp_path.to_str().unwrap()); "-Wextra",
"-o",
"-",
"-I",
irrt_dir.to_str().unwrap(),
irrt_cpp_path.to_str().unwrap(),
];
match env::var("PROFILE").as_deref() { match env::var("PROFILE").as_deref() {
Ok("debug") => { Ok("debug") => {
flags.push("-O0"); flags.push("-O0");
flags.push("-DIRRT_DEBUG"); flags.push("-DIRRT_DEBUG_ASSERT");
} }
Ok("release") => { Ok("release") => {
flags.push("-O3"); flags.push("-O3");
} }
flavor => panic!("Unknown or missing build flavor {flavor:?}"), flavor => panic!("Unknown or missing build flavor {flavor:?}"),
}; }
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes // Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap()); println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
// Compile IRRT and capture the LLVM IR output // Compile IRRT and capture the LLVM IR output
let output = Command::new(CMD_IRRT_CLANG) let output = Command::new("clang-irrt")
.args(flags) .args(flags)
.output() .output()
.map(|o| { .inspect(|o| {
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap()); assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
o
}) })
.unwrap(); .unwrap();
@ -102,9 +98,7 @@ fn compile_irrt_cpp() {
file.write_all(filtered_output.as_bytes()).unwrap(); file.write_all(filtered_output.as_bytes()).unwrap();
} }
// Assemble the emitted and filtered IR to .bc let mut llvm_as = Command::new("llvm-as-irrt")
// That .bc will be integrated into nac3core's codegen
let mut llvm_as = Command::new(CMD_IRRT_LLVM_AS)
.stdin(Stdio::piped()) .stdin(Stdio::piped())
.arg("-o") .arg("-o")
.arg(out_dir.join("irrt.bc")) .arg(out_dir.join("irrt.bc"))
@ -113,48 +107,3 @@ fn compile_irrt_cpp() {
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap(); llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
assert!(llvm_as.wait().unwrap().success()); assert!(llvm_as.wait().unwrap().success());
} }
/// Compile `irrt_test.cpp` for testing
fn compile_irrt_test_cpp() {
let out_dir = get_out_dir();
let irrt_dir = get_irrt_dir();
let exe_path = out_dir.join("irrt_test.out"); // Output path of the compiled test executable
let irrt_test_cpp_path = irrt_dir.join("irrt_test.cpp");
let flags: &[&str] = &[
irrt_test_cpp_path.to_str().unwrap(),
"-x",
"c++",
"-I",
irrt_dir.to_str().unwrap(),
"-g",
"-fno-discard-value-names",
"-O0",
"-Wall",
"-Wextra",
"-Werror=return-type",
"-lm", // for `tgamma()`, `lgamma()`
"-o",
exe_path.to_str().unwrap(),
];
Command::new(CMD_IRRT_CLANG_TEST)
.args(flags)
.output()
.map(|o| {
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
o
})
.unwrap();
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
}
fn main() {
compile_irrt_cpp();
// https://github.com/rust-lang/cargo/issues/2549
// `cargo test -F test` to also build `irrt_test.cpp
if cfg!(feature = "test") {
compile_irrt_test_cpp();
}
}

View File

@ -1,10 +1,5 @@
#define IRRT_DEFINE_TYPEDEF_INTS #include "irrt/exception.hpp"
#include <irrt_everything.hpp> #include "irrt/list.hpp"
#include "irrt/math.hpp"
/* #include "irrt/ndarray.hpp"
* All IRRT implementations. #include "irrt/slice.hpp"
*
* We don't have pre-compiled objects, so we are writing all implementations in
* headers and concatenate them with `#include` into one massive source file that
* contains all the IRRT stuff.
*/

View File

@ -1,349 +0,0 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/util.hpp>
// NDArray indices are always `uint32_t`.
using NDIndexInt = uint32_t;
// The type of an index or a value describing the length of a
// range/slice is always `int32_t`.
using SliceIndex = int32_t;
namespace {
// adapted from GNU Scientific Library:
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template <typename T>
T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
template <typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len,
SizeT begin_idx, SizeT end_idx) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template <typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims,
SizeT num_dims, NDIndexInt* idxs) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template <typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims,
const NDIndexInt* indices,
SizeT num_indices) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template <typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, SizeT lhs_ndims,
const SizeT* rhs_dims, SizeT rhs_ndims,
SizeT* out_dims) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz =
i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz =
i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template <typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
} // namespace
extern "C" {
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) { \
return __nac3_int_exp_impl(base, exp); \
}
DEF_nac3_int_exp_(int32_t);
DEF_nac3_int_exp_(int64_t);
DEF_nac3_int_exp_(uint32_t);
DEF_nac3_int_exp_(uint64_t);
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end,
const SliceIndex step) {
SliceIndex diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else
// len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(
SliceIndex dest_start, SliceIndex dest_end, SliceIndex dest_step,
uint8_t* dest_arr, SliceIndex dest_arr_len, SliceIndex src_start,
SliceIndex src_end, SliceIndex src_step, uint8_t* src_arr,
SliceIndex src_arr_len, const SliceIndex size) {
/* if dest_arr_len == 0, do nothing since we do not support
* extending list
*/
if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of
* the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len =
(src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len =
(dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(dest_arr + dest_start * size,
src_arr + src_start * size, src_len * size);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca =
(dest_arr == src_arr) &&
!(max(dest_start, dest_end) < min(src_start, src_end) ||
max(src_start, src_end) < min(dest_start, dest_end));
if (need_alloca) {
uint8_t* tmp =
reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
src_ind += src_step, dest_ind += dest_step) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous
* alloca */
__builtin_memcpy(dest_arr + dest_ind * size,
src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(
dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size + size + size + size);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
int32_t __nac3_isinf(double x) { return __builtin_isinf(x); }
int32_t __nac3_isnan(double x) { return __builtin_isnan(x); }
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len,
uint32_t begin_idx, uint32_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx,
end_idx);
}
uint64_t __nac3_ndarray_calc_size64(const uint64_t* list_data,
uint64_t list_len, uint64_t begin_idx,
uint64_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx,
end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims,
uint32_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims,
uint64_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t __nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims,
const NDIndexInt* indices,
uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices,
num_indices);
}
uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims,
const NDIndexInt* indices,
uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices,
num_indices);
}
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims, uint32_t lhs_ndims,
const uint32_t* rhs_dims, uint32_t rhs_ndims,
uint32_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims,
rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
uint64_t lhs_ndims,
const uint64_t* rhs_dims,
uint64_t rhs_ndims, uint64_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims,
rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx,
out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx,
out_idx);
}
} // extern "C"

View File

@ -1,9 +1,9 @@
#pragma once #pragma once
#include <irrt/int_defs.hpp> #include "irrt/int_types.hpp"
template <typename SizeT> template<typename SizeT>
struct CSlice { struct CSlice {
uint8_t* base; void* base;
SizeT len; SizeT len;
}; };

View File

@ -1,15 +1,25 @@
#pragma once #pragma once
#define raise_debug_assert(SizeT, msg, param1, param2, param3) \ // Set in nac3core/build.rs
raise_exception(SizeT, EXN_ASSERTION_ERROR, \ #ifdef IRRT_DEBUG_ASSERT
"IRRT debug assert failed: " msg, param1, param2, param3); #define IRRT_DEBUG_ASSERT_BOOL true
#else
#define IRRT_DEBUG_ASSERT_BOOL false
#endif
#define debug_assert_eq(SizeT, lhs, rhs) \ #define raise_debug_assert(SizeT, msg, param1, param2, param3) \
if (IRRT_DEBUG_ASSERT_BOOL && (lhs) != (rhs)) { \ raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3)
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
#define debug_assert_eq(SizeT, lhs, rhs) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if ((lhs) != (rhs)) { \
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
} \
} }
#define debug_assert(SizeT, expr) \ #define debug_assert(SizeT, expr) \
if (IRRT_DEBUG_ASSERT_BOOL && !(expr)) { \ if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \ if (!(expr)) { \
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \
} \
} }

View File

@ -1,43 +1,37 @@
#pragma once #pragma once
#include <irrt/cslice.hpp> #include "irrt/cslice.hpp"
#include <irrt/int_defs.hpp> #include "irrt/int_types.hpp"
#include <irrt/util.hpp>
/** /**
* @brief The int type of ARTIQ exception IDs. * @brief The int type of ARTIQ exception IDs.
*
* It is always `int32_t`
*/ */
typedef int32_t ExceptionId; using ExceptionId = int32_t;
/* /*
* A set of exceptions IRRT can use. * Set of exceptions C++ IRRT can use.
* Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`. * Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`.
* All exception IDs are initialized by `setup_irrt_exceptions`.
*/ */
#ifdef IRRT_TESTING
// If we are doing IRRT tests (i.e., running `cargo test -F test`), define them with a fake set of IDs.
ExceptionId EXN_INDEX_ERROR = 0;
ExceptionId EXN_VALUE_ERROR = 1;
ExceptionId EXN_ASSERTION_ERROR = 2;
ExceptionId EXN_RUNTIME_ERROR = 3;
ExceptionId EXN_TYPE_ERROR = 4;
#else
extern "C" { extern "C" {
ExceptionId EXN_INDEX_ERROR; ExceptionId EXN_INDEX_ERROR;
ExceptionId EXN_VALUE_ERROR; ExceptionId EXN_VALUE_ERROR;
ExceptionId EXN_ASSERTION_ERROR; ExceptionId EXN_ASSERTION_ERROR;
ExceptionId EXN_RUNTIME_ERROR;
ExceptionId EXN_TYPE_ERROR; ExceptionId EXN_TYPE_ERROR;
} }
#endif
/**
* @brief Extern function to `__nac3_raise`
*
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
*/
extern "C" void __nac3_raise(void* err);
namespace { namespace {
/** /**
* @brief NAC3's Exception struct * @brief NAC3's Exception struct
*/ */
template <typename SizeT> template<typename SizeT>
struct Exception { struct Exception {
ExceptionId id; ExceptionId id;
CSlice<SizeT> filename; CSlice<SizeT> filename;
@ -47,51 +41,36 @@ struct Exception {
CSlice<SizeT> msg; CSlice<SizeT> msg;
int64_t params[3]; int64_t params[3];
}; };
} // namespace
// Declare/Define `__nac3_raise` constexpr int64_t NO_PARAM = 0;
#ifdef IRRT_TESTING
#include <cstdio>
void __nac3_raise(void* err) {
// TODO: Print the error content?
printf("__nac3_raise called. Exiting...\n");
exit(1);
}
#else
/**
* @brief Extern function to `__nac3_raise`
*
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
*/
extern "C" void __nac3_raise(void* err);
#endif
namespace { template<typename SizeT>
const int64_t NO_PARAM = 0; void _raise_exception_helper(ExceptionId id,
const char* filename,
// Helper function to raise an exception with `__nac3_raise` int32_t line,
// Do not use this function directly. See `raise_exception`. const char* function,
template <typename SizeT> const char* msg,
void _raise_exception_helper(ExceptionId id, const char* filename, int32_t line, int64_t param0,
const char* function, const char* msg, int64_t param1,
int64_t param0, int64_t param1, int64_t param2) { int64_t param2) {
Exception<SizeT> e = { Exception<SizeT> e = {
.id = id, .id = id,
.filename = {.base = (uint8_t*)filename, .filename = {.base = reinterpret_cast<void*>(const_cast<char*>(filename)),
.len = (int32_t)cstr_utils::length(filename)}, .len = static_cast<SizeT>(__builtin_strlen(filename))},
.line = line, .line = line,
.column = 0, .column = 0,
.function = {.base = (uint8_t*)function, .function = {.base = reinterpret_cast<void*>(const_cast<char*>(function)),
.len = (int32_t)cstr_utils::length(function)}, .len = static_cast<SizeT>(__builtin_strlen(function))},
.msg = {.base = (uint8_t*)msg, .len = (int32_t)cstr_utils::length(msg)}, .msg = {.base = reinterpret_cast<void*>(const_cast<char*>(msg)),
.len = static_cast<SizeT>(__builtin_strlen(msg))},
}; };
e.params[0] = param0; e.params[0] = param0;
e.params[1] = param1; e.params[1] = param1;
e.params[2] = param2; e.params[2] = param2;
__nac3_raise((void*)&e); __nac3_raise(reinterpret_cast<void*>(&e));
__builtin_unreachable(); __builtin_unreachable();
} }
} // namespace
/** /**
* @brief Raise an exception with location details (location in the IRRT source files). * @brief Raise an exception with location details (location in the IRRT source files).
@ -99,25 +78,8 @@ void _raise_exception_helper(ExceptionId id, const char* filename, int32_t line,
* @param id The ID of the exception to raise. * @param id The ID of the exception to raise.
* @param msg A global constant C-string of the error message. * @param msg A global constant C-string of the error message.
* *
* `param0` and `param2` are optional format arguments of `msg`. They should be set to * `param0` to `param2` are optional format arguments of `msg`. They should be set to
* `NO_PARAM` to indicate they are unused. * `NO_PARAM` to indicate they are unused.
*/ */
#define raise_exception(SizeT, id, msg, param0, param1, param2) \ #define raise_exception(SizeT, id, msg, param0, param1, param2) \
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, \ _raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)
param0, param1, param2)
/**
* @brief Throw a dummy error for testing.
*/
template <typename SizeT>
void throw_dummy_error() {
raise_exception(SizeT, EXN_RUNTIME_ERROR, "dummy error", NO_PARAM, NO_PARAM,
NO_PARAM);
}
} // namespace
extern "C" {
void __nac3_throw_dummy_error() { throw_dummy_error<int32_t>(); }
void __nac3_throw_dummy_error64() { throw_dummy_error<int64_t>(); }
}

View File

@ -1,12 +0,0 @@
#pragma once
// This is made toggleable since `irrt_test.cpp` itself would include
// headers that define these typedefs
#ifdef IRRT_DEFINE_TYPEDEF_INTS
using int8_t = _BitInt(8);
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32);
using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
#endif

View File

@ -0,0 +1,27 @@
#pragma once
#if __STDC_VERSION__ >= 202000
using int8_t = _BitInt(8);
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32);
using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
#else
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-type"
using int8_t = _ExtInt(8);
using uint8_t = unsigned _ExtInt(8);
using int32_t = _ExtInt(32);
using uint32_t = unsigned _ExtInt(32);
using int64_t = _ExtInt(64);
using uint64_t = unsigned _ExtInt(64);
#pragma clang diagnostic pop
#endif
// NDArray indices are always `uint32_t`.
using NDIndex = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t;

View File

@ -1,56 +1,81 @@
#pragma once #pragma once
#include <irrt/int_defs.hpp> #include "irrt/int_types.hpp"
#include <irrt/slice.hpp> #include "irrt/math_util.hpp"
namespace {
/**
* @brief A list in NAC3.
*
* The `items` field is opaque. You must rely on external contexts to
* know how to interpret it.
*/
template <typename SizeT>
struct List {
uint8_t* items;
SizeT len;
};
namespace list {
template <typename SizeT>
void slice_assign(List<SizeT>* dst, List<SizeT>* src, SizeT itemsize,
UserSlice* user_slice) {
Slice slice = user_slice->indices_checked<SizeT>(dst->len);
// NOTE: Python does not have this restriction.
if (slice.len() != src->len) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"List destination has {} item(s), but source has {} "
"item(s). The lengths must match.",
slice.len(), src->len, NO_PARAM);
}
// TODO: Look into how the original implementation was implemented and optimized.
SizeT dst_i = slice.start;
SizeT src_i = 0;
while (src_i < slice.len()) {
__builtin_memcpy(dst->items + dst_i, src->items + src_i, itemsize);
src_i += 1;
dst_i += slice.step;
}
}
} // namespace list
} // namespace
extern "C" { extern "C" {
void __nac3_list_slice_assign(List<int32_t>* dst, List<int32_t>* src, // Handle list assignment and dropping part of the list when
int32_t itemsize, UserSlice* user_slice) { // both dest_step and src_step are +1.
list::slice_assign(dst, src, itemsize, user_slice); // - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
void* dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
void* src_arr,
SliceIndex src_arr_len,
const SliceIndex size) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0)
return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_start * size,
static_cast<uint8_t*>(src_arr) + src_start * size, src_len * size);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + (dest_start + src_len) * size,
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca = (dest_arr == src_arr)
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end));
if (need_alloca) {
void* tmp = __builtin_alloca(src_arr_len * size);
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind, static_cast<uint8_t*>(src_arr) + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 4,
static_cast<uint8_t*>(src_arr) + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 8,
static_cast<uint8_t*>(src_arr) + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
static_cast<uint8_t*>(src_arr) + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
} }
} // extern "C"
void __nac3_list_slice_assign64(List<int64_t>* dst, List<int64_t>* src,
int64_t itemsize, UserSlice* user_slice) {
list::slice_assign(dst, src, itemsize, user_slice);
}
}

View File

@ -0,0 +1,93 @@
#pragma once
namespace {
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template<typename T>
T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
} // namespace
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) { \
return __nac3_int_exp_impl(base, exp); \
}
extern "C" {
// Putting semicolons here to make clang-format not reformat this into
// a stair shape.
DEF_nac3_int_exp_(int32_t);
DEF_nac3_int_exp_(int64_t);
DEF_nac3_int_exp_(uint32_t);
DEF_nac3_int_exp_(uint64_t);
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
} // namespace

View File

@ -0,0 +1,13 @@
#pragma once
namespace {
template<typename T>
const T& max(const T& a, const T& b) {
return a > b ? a : b;
}
template<typename T>
const T& min(const T& a, const T& b) {
return a > b ? b : a;
}
} // namespace

View File

@ -0,0 +1,144 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template<typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template<typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
SizeT lhs_ndims,
const SizeT* rhs_dims,
SizeT rhs_ndims,
SizeT* out_dims) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
} // namespace
extern "C" {
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
uint32_t lhs_ndims,
const uint32_t* rhs_dims,
uint32_t rhs_ndims,
uint32_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
uint64_t lhs_ndims,
const uint64_t* rhs_dims,
uint64_t rhs_ndims,
uint64_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
} // namespace

View File

@ -1,119 +0,0 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/list.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp>
namespace {
namespace ndarray {
namespace array {
// TODO: Document me
template <typename SizeT>
void set_and_validate_list_shape_helper(SizeT axis, List<SizeT>* list,
SizeT ndims, SizeT* shape) {
if (shape[axis] == -1) {
// Dimension is unspecified. Set it.
shape[axis] = list->len;
} else {
// Dimension is specified. Check.
if (shape[axis] != list->len) {
// Mismatch, throw an error.
// NOTE: NumPy's error message is more complex and needs more PARAMS to display.
raise_exception(SizeT, EXN_VALUE_ERROR,
"The requested array has an inhomogenous shape "
"after {0} dimension(s).",
axis, shape[axis], list->len);
}
}
if (axis + 1 == ndims) {
// `list` has type `list[ItemType]`
// Do nothing
} else {
// `list` has type `list[list[...]]`
List<SizeT>** lists = (List<SizeT>**)(list->items);
for (SizeT i = 0; i < list->len; i++) {
set_and_validate_list_shape_helper<SizeT>(axis + 1, lists[i], ndims,
shape);
}
}
}
// TODO: Document me
template <typename SizeT>
void set_and_validate_list_shape(List<SizeT>* list, SizeT ndims, SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) {
shape[axis] = -1; // Sentinel to say this dimension is unspecified.
}
set_and_validate_list_shape_helper<SizeT>(0, list, ndims, shape);
}
// TODO: Document me
template <typename SizeT>
void write_list_to_array_helper(SizeT axis, SizeT* index, List<SizeT>* list,
NDArray<SizeT>* ndarray) {
debug_assert_eq(SizeT, list->len, ndarray->shape[axis]);
if (IRRT_DEBUG_ASSERT_BOOL) {
if (!ndarray::basic::is_c_contiguous(ndarray)) {
raise_debug_assert(SizeT, "ndarray is not C-contiguous", ndarray->strides[0],
ndarray->strides[1], NO_PARAM);
}
}
if (axis + 1 == ndarray->ndims) {
// `list` has type `list[ItemType]`
// `ndarray` is contiguous, so we can do this, and this is fast.
uint8_t* dst = ndarray->data + (ndarray->itemsize * (*index));
__builtin_memcpy(dst, list->items, ndarray->itemsize * list->len);
*index += list->len;
} else {
// `list` has type `list[list[...]]`
List<SizeT>** lists = (List<SizeT>**)(list->items);
for (SizeT i = 0; i < list->len; i++) {
write_list_to_array_helper<SizeT>(axis + 1, index, lists[i],
ndarray);
}
}
}
// TODO: Document me
template <typename SizeT>
void write_list_to_array(List<SizeT>* list, NDArray<SizeT>* ndarray) {
// done after set_and_validate(list, ndims, shape), list is well-formed
// ndarray->data is allocated and owned
// ndarray->itemsize is set
// ndarray->ndims is set
// ndarray->shape is set
// ndarray->strides is ???
SizeT index = 0;
write_list_to_array_helper<SizeT>((SizeT)0, &index, list, ndarray);
}
} // namespace array
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::array;
void __nac3_array_set_and_validate_list_shape(List<int32_t>* list,
int32_t ndims, int32_t* shape) {
set_and_validate_list_shape(list, ndims, shape);
}
void __nac3_array_set_and_validate_list_shape64(List<int64_t>* list,
int64_t ndims, int64_t* shape) {
set_and_validate_list_shape(list, ndims, shape);
}
void __nac3_array_write_list_to_array(List<int32_t>* list,
NDArray<int32_t>* ndarray) {
write_list_to_array(list, ndarray);
}
void __nac3_array_write_list_to_array64(List<int64_t>* list,
NDArray<int64_t>* ndarray) {
write_list_to_array(list, ndarray);
}
}

View File

@ -1,345 +0,0 @@
#pragma once
#include <irrt/exception.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
namespace {
namespace ndarray {
namespace basic {
/**
* @brief Asserts that `shape` does not contain negative dimensions.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape to check on
*/
template <typename SizeT>
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) {
if (shape[axis] < 0) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"negative dimensions are not allowed; axis {0} "
"has dimension {1}",
axis, shape[axis], NO_PARAM);
}
}
}
/**
* @brief Check two shapes are the same in the context of writing outputting to an ndarray.
*
* This function throws error messages for output shape mismatches.
*/
template <typename SizeT>
void assert_output_shape_same(SizeT ndarray_ndims, const SizeT* ndarray_shape,
SizeT output_ndims, const SizeT* output_shape) {
if (ndarray_ndims != output_ndims) {
// There is no corresponding NumPy error message like this.
raise_exception(
SizeT, EXN_VALUE_ERROR,
"Cannot write output of ndims {0} to an ndarray with ndims {1}",
output_ndims, ndarray_ndims, NO_PARAM);
}
for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
if (ndarray_shape[axis] != output_shape[axis]) {
// There is no corresponding NumPy error message like this.
raise_exception(
SizeT, EXN_VALUE_ERROR,
"Mismatched dimensions on axis {0}, output has "
"dimension {1}, but destination ndarray has dimension {2}.",
axis, output_shape[axis], ndarray_shape[axis]);
}
}
}
/**
* @brief Returns the number of elements of an ndarray given its shape.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape of the ndarray
*/
template <typename SizeT>
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
SizeT size = 1;
for (SizeT axis = 0; axis < ndims; axis++) size *= shape[axis];
return size;
}
/**
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
*
* @param ndims Number of elements in `shape` and `indices`
* @param shape The shape of the ndarray
* @param indices The returned indices indexing the ndarray with shape `shape`.
* @param nth The index of the element of interest.
*/
template <typename SizeT>
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices,
SizeT nth) {
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = ndims - i - 1;
SizeT dim = shape[axis];
indices[axis] = nth % dim;
nth /= dim;
}
}
/**
* @brief Return the number of elements of an `ndarray`
*
* This function corresponds to `<an_ndarray>.size`
*/
template <typename SizeT>
SizeT size(const NDArray<SizeT>* ndarray) {
return calc_size_from_shape(ndarray->ndims, ndarray->shape);
}
/**
* @brief Return of the number of its content of an `ndarray`.
*
* This function corresponds to `<an_ndarray>.nbytes`.
*/
template <typename SizeT>
SizeT nbytes(const NDArray<SizeT>* ndarray) {
return size(ndarray) * ndarray->itemsize;
}
/**
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
*
* This function corresponds to `<an_ndarray>.__len__`.
*
* @param dst_length The returned result
*/
template <typename SizeT>
SizeT len(const NDArray<SizeT>* ndarray) {
// numpy prohibits `__len__` on unsized objects
if (ndarray->ndims == 0) {
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object",
NO_PARAM, NO_PARAM, NO_PARAM);
} else {
return ndarray->shape[0];
}
}
/**
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
*
* You may want to see: ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
*/
template <typename SizeT>
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
// Other references:
// - tinynumpy's implementation: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
// - ndarray's flags["C_CONTIGUOUS"]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
// - ndarray's rules for C-contiguity: https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
// From https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
//
// The traditional rule is that for an array to be flagged as C contiguous,
// the following must hold:
//
// strides[-1] == itemsize
// strides[i] == shape[i+1] * strides[i + 1]
// [...]
// According to these rules, a 0- or 1-dimensional array is either both
// C- and F-contiguous, or neither; and an array with 2+ dimensions
// can be C- or F- contiguous, or neither, but not both. Though there
// there are exceptions for arrays with zero or one item, in the first
// case the check is relaxed up to and including the first dimension
// with shape[i] == 0. In the second case `strides == itemsize` will
// can be true for all dimensions and both flags are set.
if (ndarray->ndims == 0) {
return true;
}
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
return false;
}
for (SizeT i = 1; i < ndarray->ndims; i++) {
SizeT axis_i = ndarray->ndims - i - 1;
if (ndarray->strides[axis_i] !=
ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
return false;
}
}
return true;
}
/**
* @brief Return the pointer to the element indexed by `indices`.
*/
template <typename SizeT>
uint8_t* get_pelement_by_indices(const NDArray<SizeT>* ndarray,
const SizeT* indices) {
uint8_t* element = ndarray->data;
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
element += indices[dim_i] * ndarray->strides[dim_i];
return element;
}
/**
* @brief Convenience function. Like `get_pelement_by_indices` but
* reinterprets the element pointer.
*/
template <typename SizeT, typename T>
T* get_ptr(const NDArray<SizeT>* ndarray, const SizeT* indices) {
return (T*)get_pelement_by_indices(ndarray, indices);
}
/**
* @brief Return the pointer to the nth (0-based) element in a flattened view of `ndarray`.
*
* This function does no bound check.
*/
template <typename SizeT>
uint8_t* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
uint8_t* element = ndarray->data;
for (SizeT i = 0; i < ndarray->ndims; i++) {
SizeT axis = ndarray->ndims - i - 1;
SizeT dim = ndarray->shape[axis];
element += ndarray->strides[axis] * (nth % dim);
nth /= dim;
}
return element;
}
/**
* @brief Update the strides of an ndarray given an ndarray `shape`
* and assuming that the ndarray is fully c-contagious.
*
* You might want to read https://ajcr.net/stride-guide-part-1/.
*/
template <typename SizeT>
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
SizeT stride_product = 1;
for (SizeT i = 0; i < ndarray->ndims; i++) {
SizeT axis = ndarray->ndims - i - 1;
ndarray->strides[axis] = stride_product * ndarray->itemsize;
stride_product *= ndarray->shape[axis];
}
}
/**
* @brief Set an element in `ndarray`.
*
* @param pelement Pointer to the element in `ndarray` to be set.
* @param pvalue Pointer to the value `pelement` will be set to.
*/
template <typename SizeT>
void set_pelement_value(NDArray<SizeT>* ndarray, uint8_t* pelement,
const uint8_t* pvalue) {
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
}
/**
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
*
* Both ndarrays will be viewed in their flatten views when copying the elements.
*/
template <typename SizeT>
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// TODO: Make this faster with memcpy
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
for (SizeT i = 0; i < size(src_ndarray); i++) {
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
ndarray::basic::set_pelement_value(dst_ndarray, dst_element,
src_element);
}
}
} // namespace basic
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::basic;
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims,
int32_t* shape) {
assert_shape_no_negative(ndims, shape);
}
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims,
int64_t* shape) {
assert_shape_no_negative(ndims, shape);
}
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
const int32_t* ndarray_shape,
int32_t output_ndims,
const int32_t* output_shape) {
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims,
output_shape);
}
void __nac3_ndarray_util_assert_output_shape_same64(
int64_t ndarray_ndims, const int64_t* ndarray_shape, int64_t output_ndims,
const int64_t* output_shape) {
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims,
output_shape);
}
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
return size(ndarray);
}
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
return size(ndarray);
}
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
return nbytes(ndarray);
}
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
return nbytes(ndarray);
}
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) { return len(ndarray); }
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) { return len(ndarray); }
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
return is_c_contiguous(ndarray);
}
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
return is_c_contiguous(ndarray);
}
uint8_t* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray,
int32_t nth) {
return get_nth_pelement(ndarray, nth);
}
uint8_t* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray,
int64_t nth) {
return get_nth_pelement(ndarray, nth);
}
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
}

View File

@ -1,171 +0,0 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/slice.hpp>
namespace {
template <typename SizeT>
struct ShapeEntry {
SizeT ndims;
SizeT* shape;
};
} // namespace
namespace {
namespace ndarray {
namespace broadcast {
/**
* @brief Return true if `src_shape` can broadcast to `dst_shape`.
*
* See https://numpy.org/doc/stable/user/basics.broadcasting.html
*/
template <typename SizeT>
bool can_broadcast_shape_to(SizeT target_ndims, const SizeT* target_shape,
SizeT src_ndims, const SizeT* src_shape) {
if (src_ndims > target_ndims) {
return false;
}
for (SizeT i = 0; i < src_ndims; i++) {
SizeT target_dim = target_shape[target_ndims - i - 1];
SizeT src_dim = src_shape[src_ndims - i - 1];
if (!(src_dim == 1 || target_dim == src_dim)) {
return false;
}
}
return true;
}
/**
* @brief Performs `np.broadcast_shapes(<shapes>)`
*
* @param num_shapes Number of entries in `shapes`
* @param shapes The list of shape to do `np.broadcast_shapes` on.
* @param dst_ndims The length of `dst_shape`.
* `dst_ndims` must be `max([shape.ndims for shape in shapes])`, but the caller has to calculate it/provide it.
* for this function since they should already know in order to allocate `dst_shape` in the first place.
* @param dst_shape The resulting shape. Must be pre-allocated by the caller. This function calculate the result
* of `np.broadcast_shapes` and write it here.
*/
template <typename SizeT>
void broadcast_shapes(SizeT num_shapes, const ShapeEntry<SizeT>* shapes,
SizeT dst_ndims, SizeT* dst_shape) {
for (SizeT dst_axis = 0; dst_axis < dst_ndims; dst_axis++) {
dst_shape[dst_axis] = 1;
}
#ifdef IRRT_DEBUG_ASSERT
SizeT max_ndims_found = 0;
#endif
for (SizeT i = 0; i < num_shapes; i++) {
ShapeEntry<SizeT> entry = shapes[i];
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
debug_assert(SizeT, entry.ndims <= dst_ndims);
#ifdef IRRT_DEBUG_ASSERT
max_ndims_found = max(max_ndims_found, entry.ndims);
#endif
for (SizeT j = 0; j < entry.ndims; j++) {
SizeT entry_axis = entry.ndims - j - 1;
SizeT dst_axis = dst_ndims - j - 1;
SizeT entry_dim = entry.shape[entry_axis];
SizeT dst_dim = dst_shape[dst_axis];
if (dst_dim == 1) {
dst_shape[dst_axis] = entry_dim;
} else if (entry_dim == 1 || entry_dim == dst_dim) {
// Do nothing
} else {
raise_exception(SizeT, EXN_VALUE_ERROR,
"shape mismatch: objects cannot be broadcast "
"to a single shape.",
NO_PARAM, NO_PARAM, NO_PARAM);
}
}
}
// Check pre-condition: `dst_ndims` must be `max([shape.ndims for shape in shapes])`
debug_assert_eq(SizeT, max_ndims_found, dst_ndims);
}
/**
* @brief Perform `np.broadcast_to(<ndarray>, <target_shape>)` and appropriate assertions.
*
* This function attempts to broadcast `src_ndarray` to a new shape defined by `dst_ndarray.shape`,
* and return the result by modifying `dst_ndarray`.
*
* # Notes on `dst_ndarray`
* The caller is responsible for allocating space for the resulting ndarray.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, determining the length of `dst_ndarray->shape`
* - `dst_ndarray->shape` must be allocated, and must contain the desired target broadcast shape.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged.
* - `dst_ndarray->shape` is unchanged.
* - `dst_ndarray->strides` is updated accordingly by how ndarray broadcast_to works.
*/
template <typename SizeT>
void broadcast_to(const NDArray<SizeT>* src_ndarray,
NDArray<SizeT>* dst_ndarray) {
if (!ndarray::broadcast::can_broadcast_shape_to(
dst_ndarray->ndims, dst_ndarray->shape, src_ndarray->ndims,
src_ndarray->shape)) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"operands could not be broadcast together", NO_PARAM,
NO_PARAM, NO_PARAM);
}
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
for (SizeT i = 0; i < dst_ndarray->ndims; i++) {
SizeT src_axis = src_ndarray->ndims - i - 1;
SizeT dst_axis = dst_ndarray->ndims - i - 1;
if (src_axis < 0 || (src_ndarray->shape[src_axis] == 1 &&
dst_ndarray->shape[dst_axis] != 1)) {
// Freeze the steps in-place
dst_ndarray->strides[dst_axis] = 0;
} else {
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
}
}
}
} // namespace broadcast
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::broadcast;
void __nac3_ndarray_broadcast_to(NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
broadcast_to(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_broadcast_to64(NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
broadcast_to(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_broadcast_shapes(int32_t num_shapes,
const ShapeEntry<int32_t>* shapes,
int32_t dst_ndims, int32_t* dst_shape) {
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
}
void __nac3_ndarray_broadcast_shapes64(int64_t num_shapes,
const ShapeEntry<int64_t>* shapes,
int64_t dst_ndims, int64_t* dst_shape) {
broadcast_shapes(num_shapes, shapes, dst_ndims, dst_shape);
}
}

View File

@ -1,44 +0,0 @@
#pragma once
namespace {
/**
* @brief The NDArray object
*
* The official numpy implementations: https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
*/
template <typename SizeT>
struct NDArray {
/**
* @brief The underlying data this `ndarray` is pointing to.
*
* Must be set to `nullptr` to indicate that this NDArray's `data` is uninitialized.
*/
uint8_t* data;
/**
* @brief The number of bytes of a single element in `data`.
*/
SizeT itemsize;
/**
* @brief The number of dimensions of this shape.
*/
SizeT ndims;
/**
* @brief The NDArray shape, with length equal to `ndims`.
*
* Note that it may contain 0.
*/
SizeT* shape;
/**
* @brief Array strides, with length equal to `ndims`
*
* The stride values are in units of bytes, not number of elements.
*
* Note that `strides` can have negative values.
*/
SizeT* strides;
};
} // namespace

View File

@ -1,221 +0,0 @@
#pragma once
#include <irrt/exception.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/slice.hpp>
namespace {
typedef uint8_t NDIndexType;
/**
* @brief A single element index
*
* `data` points to a `SliceIndex`.
*/
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
/**
* @brief A slice index
*
* `data` points to a `UserRange`.
*/
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
/**
* @brief `np.newaxis` / `None`
*
* `data` is unused.
*/
const NDIndexType ND_INDEX_TYPE_NEWAXIS = 2;
/**
* @brief `Ellipsis` / `...`
*
* `data` is unused.
*/
const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3;
/**
* @brief An index used in ndarray indexing
*/
struct NDIndex {
/**
* @brief Enum tag to specify the type of index.
*
* Please see comments of each enum constant.
*/
NDIndexType type;
/**
* @brief The accompanying data associated with `type`.
*
* Please see comments of each enum constant.
*/
uint8_t* data;
};
} // namespace
namespace {
namespace ndarray {
namespace indexing {
/**
* @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
*
* This is function very similar to performing `dst_ndarray = src_ndarray[indexes]` in Python (where the variables
* can all be found in the parameter of this function).
*
* In other words, this function takes in an ndarray (`src_ndarray`), index it with `indexes`, and return the
* indexed array (by writing the result to `dst_ndarray`).
*
* This function also does proper assertions on `indexes`.
*
* # Notes on `dst_ndarray`
* The caller is responsible for allocating space for the resulting ndarray.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after
* indexing `src_ndarray` with `indexes`.
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged.
* - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed.
* - `dst_ndarray->strides` is updated accordingly by how ndarray indexing works.
*
* @param indexes Indexes to index `src_ndarray`, ordered in the same way you would write them in Python.
* @param src_ndarray The NDArray to be indexed.
* @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above,
*/
template <typename SizeT>
void index(SizeT num_indexes, const NDIndex* indexes,
const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// First, validate `indexes`.
// Expected value of `dst_ndarray->ndims`.
SizeT expected_dst_ndims = src_ndarray->ndims;
// To check for "too many indices for array: array is ?-dimensional, but ? were indexed"
SizeT num_indexed = 0;
// There may be ellipsis `...` in `indexes`. There can only be 0 or 1 ellipsis.
SizeT num_ellipsis = 0;
for (SizeT i = 0; i < num_indexes; i++) {
if (indexes[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
expected_dst_ndims--;
num_indexed++;
} else if (indexes[i].type == ND_INDEX_TYPE_SLICE) {
num_indexed++;
} else if (indexes[i].type == ND_INDEX_TYPE_NEWAXIS) {
expected_dst_ndims++;
} else if (indexes[i].type == ND_INDEX_TYPE_ELLIPSIS) {
num_ellipsis++;
if (num_ellipsis > 1) {
raise_exception(
SizeT, EXN_INDEX_ERROR,
"an index can only have a single ellipsis ('...')",
NO_PARAM, NO_PARAM, NO_PARAM);
}
} else {
__builtin_unreachable();
}
}
debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims);
if (src_ndarray->ndims - num_indexed < 0) {
raise_exception(SizeT, EXN_INDEX_ERROR,
"too many indices for array: array is {0}-dimensional, "
"but {1} were indexed",
src_ndarray->ndims, num_indexes, NO_PARAM);
}
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
// Reference code: https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
SizeT src_axis = 0;
SizeT dst_axis = 0;
for (SliceIndex i = 0; i < num_indexes; i++) {
const NDIndex* index = &indexes[i];
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
SliceIndex input = *((SliceIndex*)index->data);
SliceIndex k = slice::resolve_index_in_length(
src_ndarray->shape[src_axis], input);
if (k == slice::OUT_OF_BOUNDS) {
raise_exception(SizeT, EXN_INDEX_ERROR,
"index {0} is out of bounds for axis {1} "
"with size {2}",
input, src_axis, src_ndarray->shape[src_axis]);
}
dst_ndarray->data += k * src_ndarray->strides[src_axis];
src_axis++;
} else if (index->type == ND_INDEX_TYPE_SLICE) {
UserSlice* input = (UserSlice*)index->data;
Slice slice =
input->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
dst_ndarray->data +=
(SizeT)slice.start * src_ndarray->strides[src_axis];
dst_ndarray->strides[dst_axis] =
((SizeT)slice.step) * src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = (SizeT)slice.len();
dst_axis++;
src_axis++;
} else if (index->type == ND_INDEX_TYPE_NEWAXIS) {
dst_ndarray->strides[dst_axis] = 0;
dst_ndarray->shape[dst_axis] = 1;
dst_axis++;
} else if (index->type == ND_INDEX_TYPE_ELLIPSIS) {
// The number of ':' entries this '...' implies.
SizeT ellipsis_size = src_ndarray->ndims - num_indexed;
for (SizeT j = 0; j < ellipsis_size; j++) {
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_axis++;
src_axis++;
}
} else {
__builtin_unreachable();
}
}
for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) {
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
}
debug_assert_eq(SizeT, src_ndarray->ndims, src_axis);
debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis);
}
} // namespace indexing
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::indexing;
void __nac3_ndarray_index(int32_t num_indexes, NDIndex* indexes,
NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
index(num_indexes, indexes, src_ndarray, dst_ndarray);
}
void __nac3_ndarray_index64(int64_t num_indexes, NDIndex* indexes,
NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
index(num_indexes, indexes, src_ndarray, dst_ndarray);
}
}

View File

@ -1,118 +0,0 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
namespace {
/**
* @brief Helper struct to enumerate through all indices under a shape.
*
* i.e., If `shape` is `[3, 2]`, by repeating `next()`, then you get:
* - `[0, 0]`
* - `[0, 1]`
* - `[1, 0]`
* - `[1, 1]`
* - `[2, 0]`
* - `[2, 1]`
* - end.
*
* Interesting cases:
* - If ndims == 0, there is one enumeration.
* - If shape contains zeroes, there are no enumerations.
*/
template <typename SizeT>
struct NDIter {
SizeT ndims;
SizeT* shape;
SizeT* strides;
/**
* @brief The current indices.
*
* Must be allocated by the caller.
*/
SizeT* indices;
/**
* @brief The nth (0-based) index of the current indices.
*/
SizeT nth;
/**
* @brief Pointer to the current element.
*/
uint8_t* element;
/**
* @brief The product of shape.
*/
SizeT size;
// TODO:: There is something called backstrides to speedup iteration.
// See https://ajcr.net/stride-guide-part-1/, and https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
// Maybe LLVM is clever and knows how to optimize.
void initialize(SizeT ndims, SizeT* shape, SizeT* strides, uint8_t* element,
SizeT* indices) {
this->ndims = ndims;
this->shape = shape;
this->strides = strides;
this->indices = indices;
this->element = element;
// Compute size and backstrides
this->size = 1;
for (SizeT i = 0; i < ndims; i++) {
this->size *= shape[i];
}
for (SizeT axis = 0; axis < ndims; axis++) indices[axis] = 0;
nth = 0;
}
void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides,
ndarray->data, indices);
}
bool has_next() { return nth < size; }
void next() {
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = ndims - i - 1;
indices[axis]++;
if (indices[axis] >= shape[axis]) {
indices[axis] = 0;
// TODO: Can be optimized with backstrides.
element -= strides[axis] * (shape[axis] - 1);
} else {
element += strides[axis];
break;
}
}
nth++;
}
};
} // namespace
extern "C" {
void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray,
int32_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
void __nac3_nditer_initialize64(NDIter<int64_t>* iter,
NDArray<int64_t>* ndarray, int64_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
bool __nac3_nditer_has_next(NDIter<int32_t>* iter) { return iter->has_next(); }
bool __nac3_nditer_has_next64(NDIter<int64_t>* iter) { return iter->has_next(); }
void __nac3_nditer_next(NDIter<int32_t>* iter) { iter->next(); }
void __nac3_nditer_next64(NDIter<int64_t>* iter) { iter->next(); }
}

View File

@ -1,194 +0,0 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/broadcast.hpp>
#include <irrt/ndarray/iter.hpp>
// NOTE: Everything would be much easier and elegant if einsum is implemented.
namespace {
namespace ndarray {
namespace matmul {
/*
* In einsum notation, the output is the broadcasts performed by `np.einsum("...ij,...jk->...ik", a, b)`.
*
* Example:
* Suppose `a_shape == [99, 1, 97, 4, 2]`
* and `b_shape == [ 1, 98, 1, 2, 5]`,
*
* ...then `new_a_shape == [99, 98, 97, 4, 2]`,
* `new_b_shape == [99, 98, 97, 2, 5]`,
* and `dst_shape == [99, 98, 97, 4, 5]`.
* ^^^^^^^^^^ ^^^^
* (by broadcast) (4x2 @ 2x5 => 4x5)
*/
template <typename SizeT>
void calculate_shapes(SizeT a_ndims, SizeT* a_shape, SizeT b_ndims,
SizeT* b_shape, SizeT final_ndims, SizeT* new_a_shape,
SizeT* new_b_shape, SizeT* dst_shape) {
debug_assert(SizeT, a_ndims >= 2);
debug_assert(SizeT, b_ndims >= 2);
debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims);
const SizeT num_entries = 2;
ShapeEntry<SizeT> entries[num_entries] = {
{.ndims = a_ndims - 2, .shape = a_shape},
{.ndims = b_ndims - 2, .shape = b_shape}};
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries,
final_ndims - 2, new_a_shape);
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries,
final_ndims - 2, new_b_shape);
ndarray::broadcast::broadcast_shapes<SizeT>(num_entries, entries,
final_ndims - 2, dst_shape);
new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2];
new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1];
new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2];
new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1];
dst_shape[final_ndims - 2] = a_shape[a_ndims - 2];
dst_shape[final_ndims - 1] = b_shape[b_ndims - 1];
}
/**
* @brief Perform `np.matmul(a, b)` but the inputs are both rank >=2 matrices and `a.shape[:-2] == b.shape[:-2]`.
*
* The compatibility of `a` and `b` (for their `.shape[-2:]`) are asserted.
*
* Also see https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy-matmul.
*
* This function expects `dst_ndarray` to contain the following content when called:
* - `dst_ndarray->data` is allocated. Can be uninitialized.
* - `dst_ndarray->itemsize` is set to `sizeof(T)`.
* - `dst_ndarray->ndims` is set appropriately.
* - `dst_ndarray->shape` is set appropriately.
* - `dst_ndarray->strides` is ignored.
*
* Moreover, the shapes of `a_ndarray`, `b_ndarray`, and `dst_ndarray` **must be the same**. This implies
*/
template <typename SizeT, typename T>
void matmul_at_least_2d(NDArray<SizeT>* a_ndarray, NDArray<SizeT>* b_ndarray,
NDArray<SizeT>* dst_ndarray) {
// All inputs' ndims should be >= 2 and be the same.
debug_assert_eq(SizeT, a_ndarray->ndims, b_ndarray->ndims);
debug_assert_eq(SizeT, a_ndarray->ndims, dst_ndarray->ndims);
debug_assert(SizeT, a_ndarray->ndims >= 2);
debug_assert_eq(SizeT, a_ndarray->itemsize, sizeof(T));
debug_assert_eq(SizeT, b_ndarray->itemsize, sizeof(T));
debug_assert_eq(SizeT, dst_ndarray->itemsize, sizeof(T));
if (IRRT_DEBUG_ASSERT_BOOL) {
// Check that the shapes are the same.
for (SizeT i = 0; i < a_ndarray->ndims - 2; i++) {
if (dst_ndarray->shape[0] != a_ndarray->shape[0]) {
raise_debug_assert(
SizeT, "Bad shape. At axis {0}, a has {1}, dst has {2}", i,
a_ndarray->shape[i], dst_ndarray->shape[i]);
}
if (dst_ndarray->shape[0] != b_ndarray->shape[0]) {
raise_debug_assert(
SizeT, "Bad shape. At axis {0}, b has {1}, dst has {2}", i,
b_ndarray->shape[i], dst_ndarray->shape[i]);
}
}
}
// Number of dimensions dedicated to stacking
// e.g., [4, 6, 1, 2, 3]
// ^^^^^^^ count these
const SizeT u = a_ndarray->ndims - 2; // Alias
SizeT* a_mat_shape = a_ndarray->shape + u;
SizeT* b_mat_shape = b_ndarray->shape + u;
SizeT* dst_mat_shape = dst_ndarray->shape + u;
// Assert that dst_ndarray has the correct shape
debug_assert_eq(SizeT, dst_mat_shape[0], a_mat_shape[0]);
debug_assert_eq(SizeT, dst_mat_shape[1], b_mat_shape[1]);
// Check that a and b are compatible for matmul
if (a_mat_shape[1] != b_mat_shape[0]) {
// This is a custom error message. Different from NumPy.
raise_exception(
SizeT, EXN_VALUE_ERROR,
"Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})",
a_mat_shape[1], b_mat_shape[0], NO_PARAM);
}
// Iterate through shape[:-2]. i.e,
// Given a = [5, 4, 3, m, p] and b = [5, 4, 3, p, n]. We iterate through [5, 4, 3].
SizeT* indices =
(SizeT*)__builtin_alloca(sizeof(SizeT) * dst_ndarray->ndims);
SizeT* mat_indices = indices + u;
NDIter<SizeT> iter;
iter.initialize(u, dst_ndarray->shape, dst_ndarray->strides,
dst_ndarray->data, indices);
for (; iter.has_next(); iter.next()) {
for (SizeT i = 0; i < dst_mat_shape[0]; i++) {
for (SizeT j = 0; j < dst_mat_shape[1]; j++) {
// `indices` is being reused to index into different ndarrays.
mat_indices[0] = i;
mat_indices[1] = j;
T* d = ndarray::basic::get_ptr<SizeT, T>(dst_ndarray, indices);
*d = 0;
for (SizeT k = 0; k < a_ndarray->shape[1]; k++) {
mat_indices[0] = i;
mat_indices[1] = k;
T* a =
ndarray::basic::get_ptr<SizeT, T>(a_ndarray, indices);
mat_indices[0] = k;
mat_indices[1] = j;
T* b =
ndarray::basic::get_ptr<SizeT, T>(b_ndarray, indices);
*d += (*a) * (*b);
}
}
}
}
}
} // namespace matmul
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::matmul;
void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, int32_t* a_shape,
int32_t b_ndims, int32_t* b_shape,
int32_t final_ndims,
int32_t* new_a_shape,
int32_t* new_b_shape,
int32_t* dst_shape) {
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims,
new_a_shape, new_b_shape, dst_shape);
}
void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, int64_t* a_shape,
int64_t b_ndims, int64_t* b_shape,
int64_t final_ndims,
int64_t* new_a_shape,
int64_t* new_b_shape,
int64_t* dst_shape) {
calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims,
new_a_shape, new_b_shape, dst_shape);
}
void __nac3_ndarray_float64_matmul_at_least_2d(NDArray<int32_t>* a_ndarray,
NDArray<int32_t>* b_ndarray,
NDArray<int32_t>* dst_ndarray) {
matmul_at_least_2d<int32_t, double>(a_ndarray, b_ndarray, dst_ndarray);
}
void __nac3_ndarray_float64_matmul_at_least_2d64(
NDArray<int64_t>* a_ndarray, NDArray<int64_t>* b_ndarray,
NDArray<int64_t>* dst_ndarray) {
matmul_at_least_2d<int64_t, double>(a_ndarray, b_ndarray, dst_ndarray);
}
}

View File

@ -1,106 +0,0 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
namespace {
namespace ndarray {
namespace reshape {
/**
* @brief Perform assertions on and resolve unknown dimensions in `new_shape` in `np.reshape(<ndarray>, new_shape)`
*
* If `new_shape` indeed contains unknown dimensions (specified with `-1`, just like numpy), `new_shape` will be
* modified to contain the resolved dimension.
*
* To perform assertions on and resolve unknown dimensions in `new_shape`, we don't need the actual
* `<ndarray>` object itself, but only the `.size` of the `<ndarray>`.
*
* @param size The `.size` of `<ndarray>`
* @param new_ndims Number of elements in `new_shape`
* @param new_shape Target shape to reshape to
*/
template <typename SizeT>
void resolve_and_check_new_shape(SizeT size, SizeT new_ndims,
SizeT* new_shape) {
// Is there a -1 in `new_shape`?
bool neg1_exists = false;
// Location of -1, only initialized if `neg1_exists` is true
SizeT neg1_axis_i;
// The computed ndarray size of `new_shape`
SizeT new_size = 1;
for (SizeT axis_i = 0; axis_i < new_ndims; axis_i++) {
SizeT dim = new_shape[axis_i];
if (dim < 0) {
if (dim == -1) {
if (neg1_exists) {
// Multiple `-1` found. Throw an error.
raise_exception(SizeT, EXN_VALUE_ERROR,
"can only specify one unknown dimension",
NO_PARAM, NO_PARAM, NO_PARAM);
} else {
neg1_exists = true;
neg1_axis_i = axis_i;
}
} else {
// TODO: What? In `np.reshape` any negative dimensions is
// treated like its `-1`.
//
// Try running `np.zeros((3, 4)).reshape((-999, 2))`
//
// It is not documented by numpy.
// Throw an error for now...
raise_exception(
SizeT, EXN_VALUE_ERROR,
"Found non -1 negative dimension {0} on axis {1}", dim,
axis_i, NO_PARAM);
}
} else {
new_size *= dim;
}
}
bool can_reshape;
if (neg1_exists) {
// Let `x` be the unknown dimension
// Solve `x * <new_size> = <size>`
if (new_size == 0 && size == 0) {
// `x` has infinitely many solutions
can_reshape = false;
} else if (new_size == 0 && size != 0) {
// `x` has no solutions
can_reshape = false;
} else if (size % new_size != 0) {
// `x` has no integer solutions
can_reshape = false;
} else {
can_reshape = true;
new_shape[neg1_axis_i] = size / new_size; // Resolve dimension
}
} else {
can_reshape = (new_size == size);
}
if (!can_reshape) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"cannot reshape array of size {0} into given shape",
size, NO_PARAM, NO_PARAM);
}
}
} // namespace reshape
} // namespace ndarray
} // namespace
extern "C" {
void __nac3_ndarray_resolve_and_check_new_shape(int32_t size, int32_t new_ndims,
int32_t* new_shape) {
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
}
void __nac3_ndarray_resolve_and_check_new_shape64(int64_t size,
int64_t new_ndims,
int64_t* new_shape) {
ndarray::reshape::resolve_and_check_new_shape(size, new_ndims, new_shape);
}
}

View File

@ -1,145 +0,0 @@
#pragma once
#include <irrt/int_defs.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/slice.hpp>
/*
* Notes on `np.transpose(<array>, <axes>)`
*
* TODO: `axes`, if specified, can actually contain negative indices,
* but it is not documented in numpy.
*
* Supporting it for now.
*/
namespace {
namespace ndarray {
namespace transpose {
/**
* @brief Do assertions on `<axes>` in `np.transpose(<array>, <axes>)`.
*
* Note that `np.transpose`'s `<axe>` argument is optional. If the argument
* is specified but the user, use this function to do assertions on it.
*
* @param ndims The number of dimensions of `<array>`
* @param num_axes Number of elements in `<axes>` as specified by the user.
* This should be equal to `ndims`. If not, a "ValueError: axes don't match array" is thrown.
* @param axes The user specified `<axes>`.
*/
template <typename SizeT>
void assert_transpose_axes(SizeT ndims, SizeT num_axes, const SizeT* axes) {
if (ndims != num_axes) {
raise_exception(SizeT, EXN_VALUE_ERROR, "axes don't match array",
NO_PARAM, NO_PARAM, NO_PARAM);
}
// TODO: Optimize this
bool* axe_specified = (bool*)__builtin_alloca(sizeof(bool) * ndims);
for (SizeT i = 0; i < ndims; i++) axe_specified[i] = false;
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = slice::resolve_index_in_length(ndims, axes[i]);
if (axis == slice::OUT_OF_BOUNDS) {
// TODO: numpy actually throws a `numpy.exceptions.AxisError`
raise_exception(
SizeT, EXN_VALUE_ERROR,
"axis {0} is out of bounds for array of dimension {1}", axis,
ndims, NO_PARAM);
}
if (axe_specified[axis]) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"repeated axis in transpose", NO_PARAM, NO_PARAM,
NO_PARAM);
}
axe_specified[axis] = true;
}
}
/**
* @brief Create a transpose view of `src_ndarray` and perform proper assertions.
*
* This function is very similar to doing `dst_ndarray = np.transpose(src_ndarray, <axes>)`.
* If `<axes>` is supposed to be `None`, caller can pass in a `nullptr` to `<axes>`.
*
* The transpose view created is returned by modifying `dst_ndarray`.
*
* The caller is responsible for setting up `dst_ndarray` before calling this function.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, must be equal to `src_ndarray->ndims`.
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data` (`dst_ndarray` is just a view to `src_ndarray`)
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`
* - `dst_ndarray->ndims` is unchanged
* - `dst_ndarray->shape` is updated according to how `np.transpose` works
* - `dst_ndarray->strides` is updated according to how `np.transpose` works
*
* @param src_ndarray The NDArray to build a transpose view on
* @param dst_ndarray The resulting NDArray after transpose. Further details in the comments above,
* @param num_axes Number of elements in axes. Unused if `axes` is nullptr.
* @param axes Axes permutation. Set it to `nullptr` if `<axes>` is `None`.
*/
template <typename SizeT>
void transpose(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray,
SizeT num_axes, const SizeT* axes) {
debug_assert_eq(SizeT, src_ndarray->ndims, dst_ndarray->ndims);
const auto ndims = src_ndarray->ndims;
if (axes != nullptr) assert_transpose_axes(ndims, num_axes, axes);
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
// Check out https://ajcr.net/stride-guide-part-2/ to see how `np.transpose` works behind the scenes.
if (axes == nullptr) {
// `np.transpose(<array>, axes=None)`
/*
* Minor note: `np.transpose(<array>, axes=None)` is equivalent to
* `np.transpose(<array>, axes=[N-1, N-2, ..., 0])` - basically it
* is reversing the order of strides and shape.
*
* This is a fast implementation to handle this special (but very common) case.
*/
for (SizeT axis = 0; axis < ndims; axis++) {
dst_ndarray->shape[axis] = src_ndarray->shape[ndims - axis - 1];
dst_ndarray->strides[axis] = src_ndarray->strides[ndims - axis - 1];
}
} else {
// `np.transpose(<array>, <axes>)`
// Permute strides and shape according to `axes`, while resolving negative indices in `axes`
for (SizeT axis = 0; axis < ndims; axis++) {
// `i` cannot be OUT_OF_BOUNDS because of assertions
SizeT i = slice::resolve_index_in_length(ndims, axes[axis]);
dst_ndarray->shape[axis] = src_ndarray->shape[i];
dst_ndarray->strides[axis] = src_ndarray->strides[i];
}
}
}
} // namespace transpose
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::transpose;
void __nac3_ndarray_transpose(const NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray, int32_t num_axes,
const int32_t* axes) {
transpose(src_ndarray, dst_ndarray, num_axes, axes);
}
void __nac3_ndarray_transpose64(const NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray, int64_t num_axes,
const int64_t* axes) {
transpose(src_ndarray, dst_ndarray, num_axes, axes);
}
}

View File

@ -1,167 +1,28 @@
#pragma once #pragma once
#include <irrt/int_defs.hpp> #include "irrt/int_types.hpp"
#include <irrt/slice.hpp>
#include <irrt/util.hpp>
#include "exception.hpp" extern "C" {
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
// The type of an index or a value describing the length of a if (i < 0) {
// range/slice is always `int32_t`. i = len + i;
using SliceIndex = int32_t;
namespace {
/**
* @brief A Python-like slice with resolved indices.
*
* "Resolved indices" means that `start` and `stop` must be positive and are
* bound to a known length.
*/
struct Slice {
SliceIndex start;
SliceIndex stop;
SliceIndex step;
/**
* @brief Calculate and return the length / the number of the slice.
*
* If this were a Python range, this function would be `len(range(start, stop, step))`.
*/
SliceIndex len() {
SliceIndex diff = stop - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
} }
}; if (i < 0) {
return 0;
namespace slice { } else if (i > len) {
/** return len;
* @brief Resolve a slice index under a given length like Python indexing.
*
* In Python, if you have a `list` of length 100, `list[-1]` resolves to
* `list[99]`, so `resolve_index_in_length_clamped(100, -1)` returns `99`.
*
* If `length` is 0, 0 is returned for any value of `index`.
*
* If `index` is out of bounds, clamps the returned value between `0` and
* `length - 1` (inclusive).
*
*/
SliceIndex resolve_index_in_length_clamped(SliceIndex length,
SliceIndex index) {
if (index < 0) {
return max<SliceIndex>(length + index, 0);
} else {
return min<SliceIndex>(length, index);
} }
return i;
} }
const SliceIndex OUT_OF_BOUNDS = -1; SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
SliceIndex diff = end - start;
/** if (diff > 0 && step > 0) {
* @brief Like `resolve_index_in_length_clamped`, but returns `OUT_OF_BOUNDS` return ((diff - 1) / step) + 1;
* if `index` is out of bounds. } else if (diff < 0 && step < 0) {
*/ return ((diff + 1) / step) + 1;
SliceIndex resolve_index_in_length(SliceIndex length, SliceIndex index) {
SliceIndex resolved = index < 0 ? length + index : index;
if (0 <= resolved && resolved < length) {
return resolved;
} else { } else {
return OUT_OF_BOUNDS; return 0;
} }
} }
} // namespace slice
/**
* @brief A Python-like slice with **unresolved** indices.
*/
struct UserSlice {
bool start_defined;
SliceIndex start;
bool stop_defined;
SliceIndex stop;
bool step_defined;
SliceIndex step;
UserSlice() { this->reset(); }
void reset() {
this->start_defined = false;
this->stop_defined = false;
this->step_defined = false;
}
void set_start(SliceIndex start) {
this->start_defined = true;
this->start = start;
}
void set_stop(SliceIndex stop) {
this->stop_defined = true;
this->stop = stop;
}
void set_step(SliceIndex step) {
this->step_defined = true;
this->step = step;
}
/**
* @brief Resolve this slice.
*
* In Python, this would be `slice(start, stop, step).indices(length)`.
*
* @return A `Slice` with the resolved indices.
*/
Slice indices(SliceIndex length) {
Slice result;
result.step = step_defined ? step : 1;
bool step_is_negative = result.step < 0;
if (start_defined) {
result.start =
slice::resolve_index_in_length_clamped(length, start);
} else {
result.start = step_is_negative ? length - 1 : 0;
}
if (stop_defined) {
result.stop = slice::resolve_index_in_length_clamped(length, stop);
} else {
result.stop = step_is_negative ? -1 : length;
}
return result;
}
/**
* @brief Like `.indices()` but with assertions.
*/
template <typename SizeT>
Slice indices_checked(SliceIndex length) {
// TODO: Switch to `SizeT length`
if (length < 0) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"length should not be negative, got {0}", length,
NO_PARAM, NO_PARAM);
}
if (this->step_defined && this->step == 0) {
raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero",
NO_PARAM, NO_PARAM, NO_PARAM);
}
return this->indices(length);
}
};
} // namespace } // namespace

View File

@ -1,101 +0,0 @@
#pragma once
namespace {
template <typename T>
const T& max(const T& a, const T& b) {
return a > b ? a : b;
}
template <typename T>
const T& min(const T& a, const T& b) {
return a > b ? b : a;
}
template <typename T>
bool arrays_match(int len, T* as, T* bs) {
for (int i = 0; i < len; i++) {
if (as[i] != bs[i]) return false;
}
return true;
}
namespace cstr_utils {
/**
* @brief Return true if `str` is empty.
*/
bool is_empty(const char* str) { return str[0] == '\0'; }
/**
* @brief Implementation of `strcmp()`
*/
int8_t compare(const char* a, const char* b) {
uint32_t i = 0;
while (true) {
if (a[i] < b[i]) {
return -1;
} else if (a[i] > b[i]) {
return 1;
} else {
if (a[i] == '\0') {
return 0;
} else {
i++;
}
}
}
}
/**
* @brief Return true two strings have the same content.
*/
int8_t equal(const char* a, const char* b) { return compare(a, b) == 0; }
/**
* @brief Implementation of `strlen()`.
*/
uint32_t length(const char* str) {
uint32_t length = 0;
while (*str != '\0') {
length++;
str++;
}
return length;
}
/**
* @brief Copy a null-terminated string to a buffer with limited size and guaranteed null-termination.
*
* `dst_max_size` must be greater than 0, otherwise this function has undefined behavior.
*
* This function attempts to copy everything from `src` from `dst`, and *always* null-terminates `dst`.
*
* If the size of `dst` is too small, the final byte (`dst[dst_max_size - 1]`) of `dst` will be set to
* the null terminator.
*
* @param src String to copy from.
* @param dst Buffer to copy string to.
* @param dst_max_size
* Number of bytes of this buffer, including the space needed for the null terminator.
* Must be greater than 0.
* @return If `dst` is too small to contain everything in `src`.
*/
bool copy(const char* src, char* dst, uint32_t dst_max_size) {
for (uint32_t i = 0; i < dst_max_size; i++) {
bool is_last = i + 1 == dst_max_size;
if (is_last && src[i] != '\0') {
dst[i] = '\0';
return false;
}
if (src[i] == '\0') {
dst[i] = '\0';
return true;
}
dst[i] = src[i];
}
__builtin_unreachable();
}
} // namespace cstr_utils
} // namespace

View File

@ -1,24 +0,0 @@
#pragma once
#ifdef IRRT_DEBUG
#define IRRT_DEBUG_ASSERT
#define IRRT_DEBUG_ASSERT_BOOL true
#else
#define IRRT_DEBUG_ASSERT_BOOL false
#endif
#include <irrt/core.hpp>
#include <irrt/debug.hpp>
#include <irrt/exception.hpp>
#include <irrt/int_defs.hpp>
#include <irrt/list.hpp>
#include <irrt/ndarray/array.hpp>
#include <irrt/ndarray/basic.hpp>
#include <irrt/ndarray/broadcast.hpp>
#include <irrt/ndarray/def.hpp>
#include <irrt/ndarray/indexing.hpp>
#include <irrt/ndarray/iter.hpp>
#include <irrt/ndarray/product.hpp>
#include <irrt/ndarray/reshape.hpp>
#include <irrt/ndarray/transpose.hpp>
#include <irrt/util.hpp>

View File

@ -1,25 +0,0 @@
// This file will be compiled like a real C++ program,
// and we do have the luxury to use the standard libraries.
// That is if the nix flakes do not have issues... especially on msys2...
#include <cstdint>
#include <cstdio>
#include <cstdlib>
// Special macro to inform `#include <irrt/*>` that we are testing.
#define IRRT_TESTING
// Note that failure unit tests are not supported.
#include <test/test_core.hpp>
#include <test/test_ndarray_basic.hpp>
#include <test/test_ndarray_broadcast.hpp>
#include <test/test_ndarray_indexing.hpp>
int main() {
test::core::run();
test::ndarray_basic::run();
test::ndarray_indexing::run();
test::ndarray_broadcast::run();
return 0;
}

View File

@ -1,11 +0,0 @@
#pragma once
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <irrt_everything.hpp>
#include <test/util.hpp>
/*
Include this header for every test_*.cpp
*/

View File

@ -1,16 +0,0 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace core {
void test_int_exp() {
BEGIN_TEST();
assert_values_match(125L, __nac3_int_exp_impl<int64_t>(5, 3));
assert_values_match(3125L, __nac3_int_exp_impl<int64_t>(5, 5));
}
void run() { test_int_exp(); }
} // namespace core
} // namespace test

View File

@ -1,30 +0,0 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace ndarray_basic {
void test_calc_size_from_shape_normal() {
// Test shapes with normal values
BEGIN_TEST();
int64_t shape[4] = {2, 3, 5, 7};
assert_values_match(
210L, ndarray::basic::util::calc_size_from_shape<int64_t>(4, shape));
}
void test_calc_size_from_shape_has_zero() {
// Test shapes with 0 in them
BEGIN_TEST();
int64_t shape[4] = {2, 0, 5, 7};
assert_values_match(
0L, ndarray::basic::util::calc_size_from_shape<int64_t>(4, shape));
}
void run() {
test_calc_size_from_shape_normal();
test_calc_size_from_shape_has_zero();
}
} // namespace ndarray_basic
} // namespace test

View File

@ -1,127 +0,0 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace ndarray_broadcast {
void test_can_broadcast_shape() {
BEGIN_TEST();
assert_values_match(true,
ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){3}, 5, (int32_t[]){1, 1, 1, 1, 3}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){3}, 2, (int32_t[]){3, 1}));
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){3}, 1, (int32_t[]){3}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){1}, 1, (int32_t[]){3}));
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){1}, 1, (int32_t[]){1}));
assert_values_match(
true, ndarray::broadcast::util::can_broadcast_shape_to(
3, (int32_t[]){256, 256, 3}, 3, (int32_t[]){256, 1, 3}));
assert_values_match(true,
ndarray::broadcast::util::can_broadcast_shape_to(
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){3}));
assert_values_match(false,
ndarray::broadcast::util::can_broadcast_shape_to(
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){2}));
assert_values_match(true,
ndarray::broadcast::util::can_broadcast_shape_to(
3, (int32_t[]){256, 256, 3}, 1, (int32_t[]){1}));
// In cases when the shapes contain zero(es)
assert_values_match(true, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){0}, 1, (int32_t[]){1}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
1, (int32_t[]){0}, 1, (int32_t[]){2}));
assert_values_match(true,
ndarray::broadcast::util::can_broadcast_shape_to(
4, (int32_t[]){0, 4, 0, 0}, 1, (int32_t[]){1}));
assert_values_match(
true, ndarray::broadcast::util::can_broadcast_shape_to(
4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 1, 1, 1}));
assert_values_match(
true, ndarray::broadcast::util::can_broadcast_shape_to(
4, (int32_t[]){0, 4, 0, 0}, 4, (int32_t[]){1, 4, 1, 1}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 3}));
assert_values_match(false, ndarray::broadcast::util::can_broadcast_shape_to(
2, (int32_t[]){4, 3}, 2, (int32_t[]){0, 0}));
}
void test_ndarray_broadcast() {
/*
# array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
# >>> [[19.9 29.9 39.9 49.9]]
#
# array = np.broadcast_to(array, (2, 3, 4))
# >>> [[[19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]]
# >>> [[19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]]]
#
# assery array.strides == (0, 0, 8)
*/
BEGIN_TEST();
double in_data[4] = {19.9, 29.9, 39.9, 49.9};
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = {1, 4};
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {.data = (uint8_t*)in_data,
.itemsize = sizeof(double),
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides};
ndarray::basic::set_strides_by_shape(&ndarray);
const int32_t dst_ndims = 3;
int32_t dst_shape[dst_ndims] = {2, 3, 4};
int32_t dst_strides[dst_ndims] = {};
NDArray<int32_t> dst_ndarray = {
.ndims = dst_ndims, .shape = dst_shape, .strides = dst_strides};
ndarray::broadcast::broadcast_to(&ndarray, &dst_ndarray);
assert_arrays_match(dst_ndims, ((int32_t[]){0, 0, 8}), dst_ndarray.strides);
assert_values_match(19.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 0, 0}))));
assert_values_match(29.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 0, 1}))));
assert_values_match(39.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 0, 2}))));
assert_values_match(49.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 0, 3}))));
assert_values_match(19.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 1, 0}))));
assert_values_match(29.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 1, 1}))));
assert_values_match(39.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 1, 2}))));
assert_values_match(49.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){0, 1, 3}))));
assert_values_match(49.9,
*((double*)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, ((int32_t[]){1, 2, 3}))));
}
void run() {
test_can_broadcast_shape();
test_ndarray_broadcast();
}
} // namespace ndarray_broadcast
} // namespace test

View File

@ -1,165 +0,0 @@
#pragma once
#include <test/includes.hpp>
namespace test {
namespace ndarray_indexing {
void test_normal_1() {
/*
Reference Python code:
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[-2:, 1::2]
# array([[ 5., 7.],
# [ 9., 11.]])
assert dst_ndarray.shape == (2, 2)
assert dst_ndarray.strides == (32, 16)
assert dst_ndarray[0, 0] == 5.0
assert dst_ndarray[0, 1] == 7.0
assert dst_ndarray[1, 0] == 9.0
assert dst_ndarray[1, 1] == 11.0
```
*/
BEGIN_TEST();
// Prepare src_ndarray
double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
6.0, 7.0, 8.0, 9.0, 10.0, 11.0};
int64_t src_itemsize = sizeof(double);
const int64_t src_ndims = 2;
int64_t src_shape[src_ndims] = {3, 4};
int64_t src_strides[src_ndims] = {};
NDArray<int64_t> src_ndarray = {.data = (uint8_t *)src_data,
.itemsize = src_itemsize,
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Prepare dst_ndarray
const int64_t dst_ndims = 2;
int64_t dst_shape[dst_ndims] = {999, 999}; // Empty values
int64_t dst_strides[dst_ndims] = {999, 999}; // Empty values
NDArray<int64_t> dst_ndarray = {.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
// Create the subscripts in `ndarray[-2::, 1::2]`
UserSlice subscript_1;
subscript_1.set_start(-2);
UserSlice subscript_2;
subscript_2.set_start(1);
subscript_2.set_step(2);
const int64_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}};
ndarray::indexing::index(num_indexes, indexes, &src_ndarray, &dst_ndarray);
int64_t expected_shape[dst_ndims] = {2, 2};
int64_t expected_strides[dst_ndims] = {32, 16};
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
// dst_ndarray[0, 0]
assert_values_match(5.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int64_t[dst_ndims]){0, 0})));
// dst_ndarray[0, 1]
assert_values_match(7.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int64_t[dst_ndims]){0, 1})));
// dst_ndarray[1, 0]
assert_values_match(9.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int64_t[dst_ndims]){1, 0})));
// dst_ndarray[1, 1]
assert_values_match(11.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int64_t[dst_ndims]){1, 1})));
}
void test_normal_2() {
/*
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[2, ::-2]
# array([11., 9.])
assert dst_ndarray.shape == (2,)
assert dst_ndarray.strides == (-16,)
assert dst_ndarray[0] == 11.0
assert dst_ndarray[1] == 9.0
```
*/
BEGIN_TEST();
// Prepare src_ndarray
double src_data[12] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0,
6.0, 7.0, 8.0, 9.0, 10.0, 11.0};
int64_t src_itemsize = sizeof(double);
const int64_t src_ndims = 2;
int64_t src_shape[src_ndims] = {3, 4};
int64_t src_strides[src_ndims] = {};
NDArray<int64_t> src_ndarray = {.data = (uint8_t *)src_data,
.itemsize = src_itemsize,
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Prepare dst_ndarray
const int64_t dst_ndims = 1;
int64_t dst_shape[dst_ndims] = {999}; // Empty values
int64_t dst_strides[dst_ndims] = {999}; // Empty values
NDArray<int64_t> dst_ndarray = {.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides};
// Create the subscripts in `ndarray[2, ::-2]`
int64_t subscript_1 = 2;
UserSlice subscript_2;
subscript_2.set_step(-2);
const int64_t num_indexes = 2;
NDIndex indexes[num_indexes] = {
{.type = ND_INDEX_TYPE_SINGLE_ELEMENT, .data = (uint8_t *)&subscript_1},
{.type = ND_INDEX_TYPE_SLICE, .data = (uint8_t *)&subscript_2}};
ndarray::indexing::index(num_indexes, indexes, &src_ndarray, &dst_ndarray);
int64_t expected_shape[dst_ndims] = {2};
int64_t expected_strides[dst_ndims] = {-16};
assert_arrays_match(dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match(dst_ndims, expected_strides, dst_ndarray.strides);
assert_values_match(11.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int64_t[dst_ndims]){0})));
assert_values_match(9.0,
*((double *)ndarray::basic::get_pelement_by_indices(
&dst_ndarray, (int64_t[dst_ndims]){1})));
}
void run() {
test_normal_1();
test_normal_2();
}
} // namespace ndarray_indexing
} // namespace test

View File

@ -1,131 +0,0 @@
#pragma once
#include <cstdio>
#include <cstdlib>
template <class T>
void print_value(const T& value);
template <>
void print_value(const bool& value) {
printf("%s", value ? "true" : "false");
}
template <>
void print_value(const int8_t& value) {
printf("%d", value);
}
template <>
void print_value(const int32_t& value) {
printf("%d", value);
}
template <>
void print_value(const int64_t& value) {
printf("%d", value);
}
template <>
void print_value(const uint8_t& value) {
printf("%u", value);
}
template <>
void print_value(const uint32_t& value) {
printf("%u", value);
}
template <>
void print_value(const uint64_t& value) {
printf("%d", value);
}
template <>
void print_value(const float& value) {
printf("%f", value);
}
template <>
void print_value(const double& value) {
printf("%f", value);
}
void __begin_test(const char* function_name, const char* file, int line) {
printf("######### Running %s @ %s:%d\n", function_name, file, line);
}
#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__)
void test_fail() {
printf("[!] Test failed. Exiting with status code 1.\n");
exit(1);
}
template <typename T>
void debug_print_array(int len, const T* as) {
printf("[");
for (int i = 0; i < len; i++) {
if (i != 0) printf(", ");
print_value(as[i]);
}
printf("]");
}
void print_assertion_passed(const char* file, int line) {
printf("[*] Assertion passed on %s:%d\n", file, line);
}
void print_assertion_failed(const char* file, int line) {
printf("[!] Assertion failed on %s:%d\n", file, line);
}
void __assert_true(const char* file, int line, bool cond) {
if (cond) {
print_assertion_passed(file, line);
} else {
print_assertion_failed(file, line);
test_fail();
}
}
#define assert_true(cond) __assert_true(__FILE__, __LINE__, cond)
template <typename T>
void __assert_arrays_match(const char* file, int line, int len,
const T* expected, const T* got) {
if (arrays_match(len, expected, got)) {
print_assertion_passed(file, line);
} else {
print_assertion_failed(file, line);
printf("Expect = ");
debug_print_array(len, expected);
printf("\n");
printf(" Got = ");
debug_print_array(len, got);
printf("\n");
test_fail();
}
}
#define assert_arrays_match(len, expected, got) \
__assert_arrays_match(__FILE__, __LINE__, len, expected, got)
template <typename T>
void __assert_values_match(const char* file, int line, T expected, T got) {
if (expected == got) {
print_assertion_passed(file, line);
} else {
print_assertion_failed(file, line);
printf("Expect = ");
print_value(expected);
printf("\n");
printf(" Got = ");
print_value(got);
printf("\n");
test_fail();
}
}
#define assert_values_match(expected, got) \
__assert_values_match(__FILE__, __LINE__, expected, got)

View File

@ -0,0 +1,21 @@
[package]
name = "nac3core_derive"
version = "0.1.0"
edition = "2021"
[lib]
proc-macro = true
[[test]]
name = "structfields_tests"
path = "tests/structfields_test.rs"
[dev-dependencies]
nac3core = { path = ".." }
trybuild = { version = "1.0", features = ["diff"] }
[dependencies]
proc-macro2 = "1.0"
proc-macro-error = "1.0"
syn = "2.0"
quote = "1.0"

View File

@ -0,0 +1,320 @@
use proc_macro::TokenStream;
use proc_macro_error::{abort, proc_macro_error};
use quote::quote;
use syn::{
parse_macro_input, spanned::Spanned, Data, DataStruct, Expr, ExprField, ExprMethodCall,
ExprPath, GenericArgument, Ident, LitStr, Path, PathArguments, Type, TypePath,
};
/// Extracts all generic arguments of a [`Type`] into a [`Vec`].
///
/// Returns [`Some`] of a possibly-empty [`Vec`] if the path of `ty` matches with
/// `expected_ty_name`, otherwise returns [`None`].
fn extract_generic_args(expected_ty_name: &'static str, ty: &Type) -> Option<Vec<GenericArgument>> {
let Type::Path(TypePath { qself: None, path, .. }) = ty else {
return None;
};
let segments = &path.segments;
if segments.len() != 1 {
return None;
};
let segment = segments.iter().next().unwrap();
if segment.ident != expected_ty_name {
return None;
}
let PathArguments::AngleBracketed(path_args) = &segment.arguments else {
return Some(Vec::new());
};
let args = &path_args.args;
Some(args.iter().cloned().collect::<Vec<_>>())
}
/// Maps a `path` matching one of the `target_idents` into the `replacement` [`Ident`].
fn map_path_to_ident(path: &Path, target_idents: &[&str], replacement: &str) -> Option<Ident> {
path.require_ident()
.ok()
.filter(|ident| target_idents.iter().any(|target| ident == target))
.map(|ident| Ident::new(replacement, ident.span()))
}
/// Extracts the left-hand side of a dot-expression.
fn extract_dot_operand(expr: &Expr) -> Option<&Expr> {
match expr {
Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
| Expr::Field(ExprField { base: operand, .. }) => Some(operand),
_ => None,
}
}
/// Replaces the top-level receiver of a dot-expression with an [`Ident`], returning `Some(&mut expr)` if the
/// replacement is performed.
///
/// The top-level receiver is the left-most receiver expression, e.g. the top-level receiver of `a.b.c.foo()` is `a`.
fn replace_top_level_receiver(expr: &mut Expr, ident: Ident) -> Option<&mut Expr> {
if let Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
| Expr::Field(ExprField { base: operand, .. }) = expr
{
return if extract_dot_operand(operand).is_some() {
if replace_top_level_receiver(operand, ident).is_some() {
Some(expr)
} else {
None
}
} else {
*operand = Box::new(Expr::Path(ExprPath {
attrs: Vec::default(),
qself: None,
path: ident.into(),
}));
Some(expr)
};
}
None
}
/// Iterates all operands to the left-hand side of the `.` of an [expression][`Expr`], i.e. the container operand of all
/// [`Expr::Field`] and the receiver operand of all [`Expr::MethodCall`].
///
/// The iterator will return the operand expressions in reverse order of appearance. For example, `a.b.c.func()` will
/// return `vec![c, b, a]`.
fn iter_dot_operands(expr: &Expr) -> impl Iterator<Item = &Expr> {
let mut o = extract_dot_operand(expr);
std::iter::from_fn(move || {
let this = o;
o = o.as_ref().and_then(|o| extract_dot_operand(o));
this
})
}
/// Normalizes a value expression for use when creating an instance of this structure, returning a
/// [`proc_macro2::TokenStream`] of tokens representing the normalized expression.
fn normalize_value_expr(expr: &Expr) -> proc_macro2::TokenStream {
match &expr {
Expr::Path(ExprPath { qself: None, path, .. }) => {
if let Some(ident) = map_path_to_ident(path, &["usize", "size_t"], "llvm_usize") {
quote! { #ident }
} else {
abort!(
path,
format!(
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
quote!(#expr).to_string(),
)
)
}
}
Expr::Call(_) => {
quote! { ctx.#expr }
}
Expr::MethodCall(_) => {
let base_receiver = iter_dot_operands(expr).last();
match base_receiver {
// `usize.{...}`, `size_t.{...}` -> Rewrite the identifiers to `llvm_usize`
Some(Expr::Path(ExprPath { qself: None, path, .. }))
if map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").is_some() =>
{
let ident =
map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").unwrap();
let mut expr = expr.clone();
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
quote!(#expr)
}
// `ctx.{...}`, `context.{...}` -> Rewrite the identifiers to `ctx`
Some(Expr::Path(ExprPath { qself: None, path, .. }))
if map_path_to_ident(path, &["ctx", "context"], "ctx").is_some() =>
{
let ident = map_path_to_ident(path, &["ctx", "context"], "ctx").unwrap();
let mut expr = expr.clone();
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
quote!(#expr)
}
// No reserved identifier prefix -> Prepend `ctx.` to the entire expression
_ => quote! { ctx.#expr },
}
}
_ => {
abort!(
expr,
format!(
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
quote!(#expr).to_string(),
)
)
}
}
}
/// Derives an implementation of `codegen::types::structure::StructFields`.
///
/// The benefit of using `#[derive(StructFields)]` is that all index- or order-dependent logic required by
/// `impl StructFields` is automatically generated by this implementation, including the field index as required by
/// `StructField::new` and the fields as returned by `StructFields::to_vec`.
///
/// # Prerequisites
///
/// In order to derive from [`StructFields`], you must implement (or derive) [`Eq`] and [`Copy`] as required by
/// `StructFields`.
///
/// Moreover, `#[derive(StructFields)]` can only be used for `struct`s with named fields, and may only contain fields
/// with either `StructField` or [`PhantomData`] types.
///
/// # Attributes for [`StructFields`]
///
/// Each `StructField` field must be declared with the `#[value_type(...)]` attribute. The argument of `value_type`
/// accepts one of the following:
///
/// - An expression returning an instance of `inkwell::types::BasicType` (with or without the receiver `ctx`/`context`).
/// For example, `context.i8_type()`, `ctx.i8_type()`, and `i8_type()` all refer to `i8`.
/// - The reserved identifiers `usize` and `size_t` referring to an `inkwell::types::IntType` of the platform-dependent
/// integer size. `usize` and `size_t` can also be used as the receiver to other method calls, e.g.
/// `usize.array_type(3)`.
///
/// # Example
///
/// The following is an example of an LLVM slice implemented using `#[derive(StructFields)]`.
///
/// ```rust,ignore
/// use nac3core::{
/// codegen::types::structure::StructField,
/// inkwell::{
/// values::{IntValue, PointerValue},
/// AddressSpace,
/// },
/// };
/// use nac3core_derive::StructFields;
///
/// // All classes that implement StructFields must also implement Eq and Copy
/// #[derive(PartialEq, Eq, Clone, Copy, StructFields)]
/// pub struct SliceValue<'ctx> {
/// // Declares ptr have a value type of i8*
/// //
/// // Can also be written as `ctx.i8_type().ptr_type(...)` or `context.i8_type().ptr_type(...)`
/// #[value_type(i8_type().ptr_type(AddressSpace::default()))]
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
///
/// // Declares len have a value type of usize, depending on the target compilation platform
/// #[value_type(usize)]
/// len: StructField<'ctx, IntValue<'ctx>>,
/// }
/// ```
#[proc_macro_derive(StructFields, attributes(value_type))]
#[proc_macro_error]
pub fn derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as syn::DeriveInput);
let ident = &input.ident;
let Data::Struct(DataStruct { fields, .. }) = &input.data else {
abort!(input, "Only structs with named fields are supported");
};
if let Err(err_span) =
fields
.iter()
.try_for_each(|field| if field.ident.is_some() { Ok(()) } else { Err(field.span()) })
{
abort!(err_span, "Only structs with named fields are supported");
};
// Check if struct<'ctx>
if input.generics.params.len() != 1 {
abort!(input.generics, "Expected exactly 1 generic parameter")
}
let phantom_info = fields
.iter()
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_some())
.map(|field| field.ident.as_ref().unwrap())
.cloned()
.collect::<Vec<_>>();
let field_info = fields
.iter()
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_none())
.map(|field| {
let ident = field.ident.as_ref().unwrap();
let ty = &field.ty;
let Some(_) = extract_generic_args("StructField", ty) else {
abort!(field, "Only StructField and PhantomData are allowed")
};
let attrs = &field.attrs;
let Some(value_type_attr) =
attrs.iter().find(|attr| attr.path().is_ident("value_type"))
else {
abort!(field, "Expected #[value_type(...)] attribute for field");
};
let Ok(value_type_expr) = value_type_attr.parse_args::<Expr>() else {
abort!(value_type_attr, "Expected expression in #[value_type(...)]");
};
let value_expr_toks = normalize_value_expr(&value_type_expr);
(ident.clone(), value_expr_toks)
})
.collect::<Vec<_>>();
// `<*>::new` impl of `StructField` and `PhantomData` for `StructFields::new`
let phantoms_create = phantom_info
.iter()
.map(|id| quote! { #id: ::std::marker::PhantomData })
.collect::<Vec<_>>();
let fields_create = field_info
.iter()
.map(|(id, ty)| {
let id_lit = LitStr::new(&id.to_string(), id.span());
quote! {
#id: ::nac3core::codegen::types::structure::StructField::create(
&mut counter,
#id_lit,
#ty,
)
}
})
.collect::<Vec<_>>();
// `.into()` impl of `StructField` for `StructFields::to_vec`
let fields_into =
field_info.iter().map(|(id, _)| quote! { self.#id.into() }).collect::<Vec<_>>();
let impl_block = quote! {
impl<'ctx> ::nac3core::codegen::types::structure::StructFields<'ctx> for #ident<'ctx> {
fn new(ctx: impl ::nac3core::inkwell::context::AsContextRef<'ctx>, llvm_usize: ::nac3core::inkwell::types::IntType<'ctx>) -> Self {
let ctx = unsafe { ::nac3core::inkwell::context::ContextRef::new(ctx.as_ctx_ref()) };
let mut counter = ::nac3core::codegen::types::structure::FieldIndexCounter::default();
#ident {
#(#fields_create),*
#(#phantoms_create),*
}
}
fn to_vec(&self) -> ::std::vec::Vec<(&'static str, ::nac3core::inkwell::types::BasicTypeEnum<'ctx>)> {
vec![
#(#fields_into),*
]
}
}
};
impl_block.into()
}

View File

@ -0,0 +1,9 @@
use nac3core_derive::StructFields;
use std::marker::PhantomData;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct EmptyValue<'ctx> {
_phantom: PhantomData<&'ctx ()>,
}
fn main() {}

View File

@ -0,0 +1,20 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDArrayValue<'ctx> {
#[value_type(usize)]
ndims: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
shape: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
data: StructField<'ctx, PointerValue<'ctx>>,
}
fn main() {}

View File

@ -0,0 +1,18 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -0,0 +1,18 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(context.i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -0,0 +1,18 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(ctx.i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -0,0 +1,18 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(size_t)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -0,0 +1,10 @@
#[test]
fn test_parse_empty() {
let t = trybuild::TestCases::new();
t.pass("tests/structfields_empty.rs");
t.pass("tests/structfields_slice.rs");
t.pass("tests/structfields_slice_ctx.rs");
t.pass("tests/structfields_slice_context.rs");
t.pass("tests/structfields_slice_sizet.rs");
t.pass("tests/structfields_ndarray.rs");
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,9 @@
use std::collections::HashMap;
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use crate::{ use crate::{
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::DefinitionId, toplevel::DefinitionId,
@ -9,10 +15,6 @@ use crate::{
}, },
}; };
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use std::collections::HashMap;
pub struct ConcreteTypeStore { pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>, store: Vec<ConcreteTypeEnum>,
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,8 +1,10 @@
use inkwell::attributes::{Attribute, AttributeLoc}; use inkwell::{
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}; attributes::{Attribute, AttributeLoc},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
};
use itertools::Either; use itertools::Either;
use crate::codegen::CodeGenContext; use super::CodeGenContext;
/// Macro to generate extern function /// Macro to generate extern function
/// Both function return type and function parameter type are `FloatValue` /// Both function return type and function parameter type are `FloatValue`
@ -13,11 +15,11 @@ use crate::codegen::CodeGenContext;
/// * `$extern_fn:literal`: Name of underlying extern function /// * `$extern_fn:literal`: Name of underlying extern function
/// ///
/// Optional Arguments: /// Optional Arguments:
/// * `$(,$attributes:literal)*)`: Attributes linked with the extern function /// * `$(,$attributes:literal)*)`: Attributes linked with the extern function.
/// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly" /// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly".
/// These will be used unless other attributes are specified /// These will be used unless other attributes are specified
/// * `$(,$args:ident)*`: Operands of the extern function /// * `$(,$args:ident)*`: Operands of the extern function
/// The data type of these operands will be set to `FloatValue` /// The data type of these operands will be set to `FloatValue`
/// ///
macro_rules! generate_extern_fn { macro_rules! generate_extern_fn {
("unary", $fn_name:ident, $extern_fn:literal) => { ("unary", $fn_name:ident, $extern_fn:literal) => {

View File

@ -1,16 +1,18 @@
use crate::{
codegen::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
use inkwell::{ use inkwell::{
context::Context, context::Context,
types::{BasicTypeEnum, IntType}, types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
}; };
use nac3parser::ast::{Expr, Stmt, StrRef}; use nac3parser::ast::{Expr, Stmt, StrRef};
use super::{bool_to_i1, bool_to_i8, expr::*, stmt::*, values::ArraySliceValue, CodeGenContext};
use crate::{
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
pub trait CodeGenerator { pub trait CodeGenerator {
/// Return the module name for the code generator. /// Return the module name for the code generator.
fn get_name(&self) -> &str; fn get_name(&self) -> &str;
@ -57,6 +59,7 @@ pub trait CodeGenerator {
/// - fun: Function signature, definition ID and the substitution key. /// - fun: Function signature, definition ID and the substitution key.
/// - params: Function parameters. Note that this does not include the object even if the /// - params: Function parameters. Note that this does not include the object even if the
/// function is a class method. /// function is a class method.
///
/// Note that this function should check if the function is generated in another thread (due to /// Note that this function should check if the function is generated in another thread (due to
/// possible race condition), see the default implementation for an example. /// possible race condition), see the default implementation for an example.
fn gen_func_instance<'ctx>( fn gen_func_instance<'ctx>(

View File

@ -0,0 +1,162 @@
use inkwell::{
types::BasicTypeEnum,
values::{BasicValueEnum, CallSiteValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
use super::calculate_len_for_slice_range;
use crate::codegen::{
macros::codegen_unreachable,
values::{ArrayLikeValue, ListValue},
CodeGenContext, CodeGenerator,
};
/// This function handles 'end' **inclusively**.
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
/// Negative index should be handled before entering this function
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
dest_arr: ListValue<'ctx>,
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
src_arr: ListValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) {
let size_ty = generator.get_size_type(ctx.ctx);
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let int32 = ctx.ctx.i32_type();
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
let slice_assign_fun = {
let ty_vec = vec![
int32.into(), // dest start idx
int32.into(), // dest end idx
int32.into(), // dest step
elem_ptr_type.into(), // dest arr ptr
int32.into(), // dest arr len
int32.into(), // src start idx
int32.into(), // src end idx
int32.into(), // src step
elem_ptr_type.into(), // src arr ptr
int32.into(), // src arr len
int32.into(), // size
];
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
ctx.module.add_function(fun_symbol, fn_t, None)
})
};
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
let dest_arr_ptr =
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
let src_arr_ptr =
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
let src_len = src_arr.load_size(ctx, Some("src.len"));
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
// index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
// throw exception if not satisfied
let src_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let dest_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let src_slice_len =
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len =
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
let src_eq_dest = ctx
.builder
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
.unwrap();
let src_slt_dest = ctx
.builder
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
.unwrap();
let dest_step_eq_one = ctx
.builder
.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
)
.unwrap();
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
ctx.make_assert(
generator,
cond,
"0:ValueError",
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
ctx.current_loc,
);
let new_len = {
let args = vec![
dest_idx.0.into(), // dest start idx
dest_idx.1.into(), // dest end idx
dest_idx.2.into(), // dest step
dest_arr_ptr.into(), // dest arr ptr
dest_len.into(), // dest arr len
src_idx.0.into(), // src start idx
src_idx.1.into(), // src end idx
src_idx.2.into(), // src step
src_arr_ptr.into(), // src arr ptr
src_len.into(), // src arr len
{
let s = match ty {
BasicTypeEnum::FloatType(t) => t.size_of(),
BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => codegen_unreachable!(ctx),
};
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
}
.into(),
];
ctx.builder
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
};
// update length
let need_update =
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let update_bb = ctx.ctx.append_basic_block(current, "update");
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
ctx.builder.position_at_end(update_bb);
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
dest_arr.store_size(ctx, generator, new_len);
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
ctx.builder.position_at_end(cont_bb);
}

View File

@ -0,0 +1,152 @@
use inkwell::{
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
IntPredicate,
};
use itertools::Either;
use crate::codegen::{
macros::codegen_unreachable,
{CodeGenContext, CodeGenerator},
};
// repeated squaring method adapted from GNU Scientific Library:
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
base: IntValue<'ctx>,
exp: IntValue<'ctx>,
signed: bool,
) -> IntValue<'ctx> {
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
(32, 32, true) => "__nac3_int_exp_int32_t",
(64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t",
_ => codegen_unreachable!(ctx),
};
let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
ctx.module.add_function(symbol, fn_type, None)
});
// throw exception when exp < 0
let ge_zero = ctx
.builder
.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
)
.unwrap();
ctx.make_assert(
generator,
ge_zero,
"0:ValueError",
"integer power must be positive or zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isinf", fn_type, None)
});
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isinf")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isnan", fn_type, None)
});
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isnan")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gamma", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gamma")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gammaln", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gammaln")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_j0", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "j0")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,384 @@
use inkwell::{
types::IntType,
values::{BasicValueEnum, CallSiteValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
use crate::codegen::{
llvm_intrinsics,
macros::codegen_unreachable,
stmt::gen_for_callback_incrementing,
values::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, NDArrayValue, TypedArrayLikeAccessor,
TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
CodeGenContext, CodeGenerator,
};
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension
/// respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
let ndarray_calc_size_fn =
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.shape();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
ctx.builder
.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Indices,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.shape();
let index = ctx
.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
index
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
rhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let rhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
let lhs_eq_rhs = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap();
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.shape().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.shape().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.shape().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}

View File

@ -0,0 +1,76 @@
use inkwell::{
values::{BasicValueEnum, CallSiteValue, IntValue},
IntPredicate,
};
use itertools::Either;
use nac3parser::ast::Expr;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
typecheck::typedef::Type,
};
/// this function allows index out of range, since python
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
i: &Expr<Option<Type>>,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<IntValue<'ctx>>, String> {
const SYMBOL: &str = "__nac3_slice_index_bound";
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
} else {
return Ok(None);
};
Ok(Some(
ctx.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap(),
))
}
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
start: IntValue<'ctx>,
end: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
const SYMBOL: &str = "__nac3_range_slice_len";
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
// assert step != 0, throw exception if not
let not_zero = ctx
.builder
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
.unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"step must not be zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,26 +0,0 @@
#[cfg(test)]
mod tests {
use std::{path::Path, process::Command};
#[test]
fn run_irrt_test() {
assert!(
cfg!(feature = "test"),
"Please do `cargo test -F test` to compile `irrt_test.out` and run test"
);
let irrt_test_out_path = Path::new(concat!(env!("OUT_DIR"), "/irrt_test.out"));
let output = Command::new(irrt_test_out_path.to_str().unwrap()).output().unwrap();
if !output.status.success() {
eprintln!("irrt_test failed with status {}:", output.status);
eprintln!("====== stdout ======");
eprintln!("{}", String::from_utf8(output.stdout).unwrap());
eprintln!("====== stderr ======");
eprintln!("{}", String::from_utf8(output.stderr).unwrap());
eprintln!("====================");
panic!("irrt_test failed");
}
}
}

View File

@ -1,12 +1,14 @@
use crate::codegen::CodeGenContext; use inkwell::{
use inkwell::context::Context; context::Context,
use inkwell::intrinsics::Intrinsic; intrinsics::Intrinsic,
use inkwell::types::AnyTypeEnum::IntType; types::{AnyTypeEnum::IntType, FloatType},
use inkwell::types::FloatType; values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue},
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue}; AddressSpace,
use inkwell::AddressSpace; };
use itertools::Either; use itertools::Either;
use super::CodeGenContext;
/// Returns the string representation for the floating-point type `ft` when used in intrinsic /// Returns the string representation for the floating-point type `ft` when used in intrinsic
/// functions. /// functions.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str { fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
@ -183,7 +185,7 @@ pub fn call_memcpy_generic<'ctx>(
dest dest
} else { } else {
ctx.builder ctx.builder
.build_bitcast(dest, llvm_p0i8, "") .build_bit_cast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap() .unwrap()
}; };
@ -191,7 +193,7 @@ pub fn call_memcpy_generic<'ctx>(
src src
} else { } else {
ctx.builder ctx.builder
.build_bitcast(src, llvm_p0i8, "") .build_bit_cast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap() .unwrap()
}; };
@ -205,8 +207,9 @@ pub fn call_memcpy_generic<'ctx>(
/// * `$ctx:ident`: Reference to the current Code Generation Context /// * `$ctx:ident`: Reference to the current Code Generation Context
/// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>) /// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>)
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function /// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
/// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type) /// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type).
/// Use `BasicValueEnum::into_int_value` for Integer return type and `BasicValueEnum::into_float_value` for Float return type /// Use `BasicValueEnum::into_int_value` for Integer return type and
/// `BasicValueEnum::into_float_value` for Float return type
/// * `$llvm_ty:ident`: Type of first operand /// * `$llvm_ty:ident`: Type of first operand
/// * `,($val:ident)*`: Comma separated list of operands /// * `,($val:ident)*`: Comma separated list of operands
macro_rules! generate_llvm_intrinsic_fn_body { macro_rules! generate_llvm_intrinsic_fn_body {
@ -222,8 +225,8 @@ macro_rules! generate_llvm_intrinsic_fn_body {
/// Arguments: /// Arguments:
/// * `float/int`: Indicates the return and argument type of the function /// * `float/int`: Indicates the return and argument type of the function
/// * `$fn_name:ident`: The identifier of the rust function to be generated /// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function /// * `$llvm_name:literal`: Name of underlying llvm intrinsic function.
/// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil" /// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil"
/// * `$val:ident`: The operand for unary operations /// * `$val:ident`: The operand for unary operations
/// * `$val1:ident`, `$val2:ident`: The operands for binary operations /// * `$val1:ident`, `$val2:ident`: The operands for binary operations
macro_rules! generate_llvm_intrinsic_fn { macro_rules! generate_llvm_intrinsic_fn {

View File

@ -1,12 +1,12 @@
use crate::{ use std::{
codegen::classes::{ListType, ProxyType, RangeType}, collections::{HashMap, HashSet},
symbol_resolver::{StaticValue, SymbolResolver}, sync::{
toplevel::{helper::PrimDef, TopLevelContext, TopLevelDef}, atomic::{AtomicBool, Ordering},
typecheck::{ Arc,
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
}, },
thread,
}; };
use crossbeam::channel::{unbounded, Receiver, Sender}; use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
@ -24,36 +24,52 @@ use inkwell::{
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use itertools::Itertools; use itertools::Itertools;
use model::*;
use nac3parser::ast::{Location, Stmt, StrRef};
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::thread;
use structure::{CSlice, Exception, NDArray};
pub mod classes; use nac3parser::ast::{Location, Stmt, StrRef};
use crate::{
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
},
};
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};
use types::{ListType, NDArrayType, ProxyType, RangeType};
pub mod builtin_fns;
pub mod concrete_type; pub mod concrete_type;
pub mod expr; pub mod expr;
pub mod extern_fns; pub mod extern_fns;
mod generator; mod generator;
pub mod irrt; pub mod irrt;
pub mod llvm_intrinsics; pub mod llvm_intrinsics;
pub mod model;
pub mod numpy; pub mod numpy;
pub mod numpy_new;
pub mod object;
pub mod stmt; pub mod stmt;
pub mod structure; pub mod types;
pub mod values;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; mod macros {
pub use generator::{CodeGenerator, DefaultCodeGenerator}; /// Codegen-variant of [`std::unreachable`] which accepts an instance of [`CodeGenContext`] as
/// its first argument to provide Python source information to indicate the codegen location
/// causing the assertion.
macro_rules! codegen_unreachable {
($ctx:expr $(,)?) => {
std::unreachable!("unreachable code while processing {}", &$ctx.current_loc)
};
($ctx:expr, $($arg:tt)*) => {
std::unreachable!("unreachable code while processing {}: {}", &$ctx.current_loc, std::format!("{}", std::format_args!($($arg)+)))
};
}
pub(crate) use codegen_unreachable;
}
#[derive(Default)] #[derive(Default)]
pub struct StaticValueStore { pub struct StaticValueStore {
@ -173,11 +189,11 @@ pub struct CodeGenContext<'ctx, 'a> {
pub registry: &'a WorkerRegistry, pub registry: &'a WorkerRegistry,
/// Cache for constant strings. /// Cache for constant strings.
pub const_strings: HashMap<String, Struct<'ctx, CSlice>>, pub const_strings: HashMap<String, BasicValueEnum<'ctx>>,
/// [`BasicBlock`] containing all `alloca` statements for the current function. /// [`BasicBlock`] containing all `alloca` statements for the current function.
pub init_bb: BasicBlock<'ctx>, pub init_bb: BasicBlock<'ctx>,
pub exception_val: Option<Ptr<'ctx, StructModel<Exception>>>, pub exception_val: Option<PointerValue<'ctx>>,
/// The header and exit basic blocks of a loop in this context. See /// The header and exit basic blocks of a loop in this context. See
/// <https://llvm.org/docs/LoopTerminology.html> for explanation of these terminology. /// <https://llvm.org/docs/LoopTerminology.html> for explanation of these terminology.
@ -494,8 +510,12 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let pndarray_model = PtrModel(StructModel(NDArray)); let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
pndarray_model.get_type(generator, ctx).as_basic_type_enum() let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype,
);
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
} }
_ => unreachable!( _ => unreachable!(
@ -581,11 +601,11 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
) -> BasicTypeEnum<'ctx> { ) -> BasicTypeEnum<'ctx> {
// If the type is used in the definition of a function, return `i1` instead of `i8` for ABI // If the type is used in the definition of a function, return `i1` instead of `i8` for ABI
// consistency. // consistency.
return if unifier.unioned(ty, primitives.bool) { if unifier.unioned(ty, primitives.bool) {
ctx.bool_type().into() ctx.bool_type().into()
} else { } else {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty) get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
}; }
} }
/// Whether `sret` is needed for a return value with type `ty`. /// Whether `sret` is needed for a return value with type `ty`.
@ -701,19 +721,43 @@ pub fn gen_func_impl<
..primitives ..primitives
}; };
let cslice_model = StructModel(CSlice); let mut type_cache: HashMap<_, _> = [
let pexn_model = PtrModel(StructModel(Exception));
let mut type_cache: HashMap<_, BasicTypeEnum<'ctx>> = [
(primitives.int32, context.i32_type().into()), (primitives.int32, context.i32_type().into()),
(primitives.int64, context.i64_type().into()), (primitives.int64, context.i64_type().into()),
(primitives.uint32, context.i32_type().into()), (primitives.uint32, context.i32_type().into()),
(primitives.uint64, context.i64_type().into()), (primitives.uint64, context.i64_type().into()),
(primitives.float, context.f64_type().into()), (primitives.float, context.f64_type().into()),
(primitives.bool, context.i8_type().into()), (primitives.bool, context.i8_type().into()),
(primitives.str, cslice_model.get_type(generator, context).into()), (primitives.str, {
let name = "str";
match module.get_struct_type(name) {
None => {
let str_type = context.opaque_struct_type("str");
let fields = [
context.i8_type().ptr_type(AddressSpace::default()).into(),
generator.get_size_type(context).into(),
];
str_type.set_body(&fields, false);
str_type.into()
}
Some(t) => t.as_basic_type_enum(),
}
}),
(primitives.range, RangeType::new(context).as_base_type().into()), (primitives.range, RangeType::new(context).as_base_type().into()),
(primitives.exception, pexn_model.get_type(generator, context).into()), (primitives.exception, {
let name = "Exception";
if let Some(t) = module.get_struct_type(name) {
t.ptr_type(AddressSpace::default()).as_basic_type_enum()
} else {
let exception = context.opaque_struct_type("Exception");
let int32 = context.i32_type().into();
let int64 = context.i64_type().into();
let str_ty = module.get_struct_type("str").unwrap().as_basic_type_enum();
let fields = [int32, str_ty, int32, int32, str_ty, str_ty, int64, int64, int64];
exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::default()).as_basic_type_enum()
}
}),
] ]
.iter() .iter()
.copied() .copied()
@ -809,10 +853,9 @@ pub fn gen_func_impl<
builder.position_at_end(init_bb); builder.position_at_end(init_bb);
let body_bb = context.append_basic_block(fn_val, "body"); let body_bb = context.append_basic_block(fn_val, "body");
// Store non-vararg argument values into local variables
let mut var_assignment = HashMap::new(); let mut var_assignment = HashMap::new();
let offset = u32::from(has_sret); let offset = u32::from(has_sret);
// Store non-vararg argument values into local variables
for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) { for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) {
let param = fn_val.get_nth_param((n as u32) + offset).unwrap(); let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
let local_type = get_llvm_type( let local_type = get_llvm_type(

View File

@ -1,40 +0,0 @@
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
};
use crate::codegen::CodeGenerator;
use super::*;
#[derive(Debug, Clone, Copy)]
pub struct AnyModel<'ctx>(pub BasicTypeEnum<'ctx>);
pub type Anything<'ctx> = Instance<'ctx, AnyModel<'ctx>>;
impl<'ctx> Model<'ctx> for AnyModel<'ctx> {
type Value = BasicValueEnum<'ctx>;
type Type = BasicTypeEnum<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> Self::Type {
self.0
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
_generator: &mut G,
_ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
if ty == self.0 {
Ok(())
} else {
Err(ModelError(format!("Expecting {}, but got {}", self.0, ty)))
}
}
}

View File

@ -1,122 +0,0 @@
use inkwell::{
context::Context,
types::{ArrayType, BasicType, BasicTypeEnum},
values::ArrayValue,
};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
/// A Model for an [`ArrayType`].
#[derive(Debug, Clone, Copy)]
pub struct ArrayModel<Element> {
pub len: u32,
pub element: Element,
}
pub type Array<'ctx, Element> = Instance<'ctx, ArrayModel<Element>>;
impl<'ctx, Element: Model<'ctx>> Model<'ctx> for ArrayModel<Element> {
type Value = ArrayValue<'ctx>;
type Type = ArrayType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.element.get_type(generator, ctx).array_type(self.len)
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let BasicTypeEnum::ArrayType(ty) = ty else {
return Err(ModelError(format!("Expecting ArrayType, but got {ty:?}")));
};
if ty.len() != self.len {
return Err(ModelError(format!(
"Expecting ArrayType with size {}, but got an ArrayType with size {}",
ty.len(),
self.len
)));
}
self.element
.check_type(generator, ctx, ty.get_element_type())
.map_err(|err| err.under_context("an ArrayType"))?;
Ok(())
}
}
impl<'ctx, Element: Model<'ctx>> Ptr<'ctx, ArrayModel<Element>> {
/// Get the pointer to the `i`-th (0-based) array element.
pub fn at<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
i: u32,
name: &str,
) -> Ptr<'ctx, Element> {
assert!(i < self.model.0.len);
let zero = ctx.ctx.i32_type().const_zero();
let i = ctx.ctx.i32_type().const_int(u64::from(i), false);
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
PtrModel(self.model.0.element).check_value(generator, ctx.ctx, ptr).unwrap()
}
}
/// Like [`ArrayModel`] but length is strongly-typed.
#[derive(Debug, Clone, Copy, Default)]
pub struct NArrayModel<const LEN: u32, Element>(pub Element);
pub type NArray<'ctx, const LEN: u32, Element> = Instance<'ctx, NArrayModel<LEN, Element>>;
impl<'ctx, const LEN: u32, Element: Model<'ctx>> NArrayModel<LEN, Element> {
/// Forget the `LEN` constant generic and get an [`ArrayModel`] with the same length.
pub fn forget_len(&self) -> ArrayModel<Element> {
ArrayModel { element: self.0, len: LEN }
}
}
impl<'ctx, const LEN: u32, Element: Model<'ctx>> Model<'ctx> for NArrayModel<LEN, Element> {
type Value = ArrayValue<'ctx>;
type Type = ArrayType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
// Convenient implementation
self.forget_len().get_type(generator, ctx)
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
// Convenient implementation
self.forget_len().check_type(generator, ctx, ty)
}
}
impl<'ctx, const LEN: u32, Element: Model<'ctx>> Ptr<'ctx, NArrayModel<LEN, Element>> {
/// Get the pointer to the `i`-th (0-based) array element.
pub fn at_const<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
i: u32,
name: &str,
) -> Ptr<'ctx, Element> {
assert!(i < LEN);
let zero = ctx.ctx.i32_type().const_zero();
let i = ctx.ctx.i32_type().const_int(u64::from(i), false);
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
PtrModel(self.model.0 .0).check_value(generator, ctx.ctx, ptr).unwrap()
}
}

View File

@ -1,123 +0,0 @@
use std::fmt;
use inkwell::{context::Context, types::*, values::*};
use super::*;
use crate::codegen::{CodeGenContext, CodeGenerator};
#[derive(Debug, Clone)]
pub struct ModelError(pub String);
impl ModelError {
pub(super) fn under_context(mut self, context: &str) -> Self {
self.0.push_str(" ... in ");
self.0.push_str(context);
self
}
}
pub trait Model<'ctx>: fmt::Debug + Clone + Copy {
type Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>>;
type Type: BasicType<'ctx>;
/// Return the [`BasicType`] of this model.
#[must_use]
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type;
/// Check if a [`BasicType`] is the same type of this model.
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError>;
/// Create an instance from a value with [`Instance::model`] being this model.
///
/// Caller must make sure the type of `value` and the type of this `model` are equivalent.
#[must_use]
fn believe_value(&self, value: Self::Value) -> Instance<'ctx, Self> {
Instance { model: *self, value }
}
/// Check if a [`BasicValue`]'s type is equivalent to the type of this model.
/// Wrap it into an [`Instance`] if it is.
fn check_value<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
value: V,
) -> Result<Instance<'ctx, Self>, ModelError> {
let value = value.as_basic_value_enum();
self.check_type(generator, ctx, value.get_type())
.map_err(|err| err.under_context(format!("the value {value:?}").as_str()))?;
let Ok(value) = Self::Value::try_from(value) else {
unreachable!("check_type() has bad implementation")
};
Ok(self.believe_value(value))
}
// Allocate a value on the stack and return its pointer.
fn alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
name: &str,
) -> Ptr<'ctx, Self> {
let pmodel = PtrModel(*self);
let p = ctx.builder.build_alloca(self.get_type(generator, ctx.ctx), name).unwrap();
pmodel.believe_value(p)
}
// Allocate an array on the stack and return its pointer.
fn array_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
len: IntValue<'ctx>,
name: &str,
) -> Ptr<'ctx, Self> {
let pmodel = PtrModel(*self);
let p =
ctx.builder.build_array_alloca(self.get_type(generator, ctx.ctx), len, name).unwrap();
pmodel.believe_value(p)
}
fn var_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&str>,
) -> Result<Ptr<'ctx, Self>, String> {
let pmodel = PtrModel(*self);
let ty = self.get_type(generator, ctx.ctx).as_basic_type_enum();
let p = generator.gen_var_alloc(ctx, ty, name)?;
Ok(pmodel.believe_value(p))
}
fn array_var_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
len: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Result<Ptr<'ctx, Self>, String> {
// TODO: Remove ArraySliceValue
let pmodel = PtrModel(*self);
let ty = self.get_type(generator, ctx.ctx).as_basic_type_enum();
let p = generator.gen_array_var_alloc(ctx, ty, len, name)?;
Ok(pmodel.believe_value(PointerValue::from(p)))
}
}
#[derive(Debug, Clone, Copy)]
pub struct Instance<'ctx, M: Model<'ctx>> {
/// The model of this instance.
pub model: M,
/// The value of this instance.
///
/// Caller must make sure the type of `value` and the type of this `model` are equivalent,
/// down to having the same [`IntType::get_bit_width`] in case of [`IntType`] for example.
pub value: M::Value,
}

View File

@ -1,88 +0,0 @@
use std::fmt;
use inkwell::{context::Context, types::FloatType, values::FloatValue};
use crate::codegen::CodeGenerator;
use super::*;
pub trait FloatKind<'ctx>: fmt::Debug + Clone + Copy {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx>;
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Float32;
#[derive(Debug, Clone, Copy, Default)]
pub struct Float64;
impl<'ctx> FloatKind<'ctx> for Float32 {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx> {
ctx.f32_type()
}
}
impl<'ctx> FloatKind<'ctx> for Float64 {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> FloatType<'ctx> {
ctx.f64_type()
}
}
#[derive(Debug, Clone, Copy)]
pub struct AnyFloat<'ctx>(FloatType<'ctx>);
impl<'ctx> FloatKind<'ctx> for AnyFloat<'ctx> {
fn get_float_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> FloatType<'ctx> {
self.0
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct FloatModel<N>(pub N);
pub type Float<'ctx, N> = Instance<'ctx, FloatModel<N>>;
impl<'ctx, N: FloatKind<'ctx>> Model<'ctx> for FloatModel<N> {
type Value = FloatValue<'ctx>;
type Type = FloatType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_float_type(generator, ctx)
}
fn check_type<T: inkwell::types::BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = FloatType::try_from(ty) else {
return Err(ModelError(format!("Expecting FloatType, but got {ty:?}")));
};
let exp_ty = self.0.get_float_type(generator, ctx);
// TODO: Inkwell does not have get_bit_width for FloatType?
// TODO: Quick hack for now, but does this actually work?
if ty != exp_ty {
return Err(ModelError(format!("Expecting {exp_ty:?}, but got {ty:?}")));
}
Ok(())
}
}

View File

@ -1,125 +0,0 @@
use inkwell::{
attributes::{Attribute, AttributeLoc},
types::{BasicMetadataTypeEnum, BasicType, FunctionType},
values::{AnyValue, BasicMetadataValueEnum, BasicValue, BasicValueEnum, CallSiteValue},
};
use itertools::Itertools;
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
// When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}".
// When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64".
#[must_use]
pub fn get_sizet_dependent_function_name<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'_, '_>,
name: &str,
) -> String {
let mut name = name.to_owned();
match generator.get_size_type(ctx.ctx).get_bit_width() {
32 => {}
64 => name.push_str("64"),
bit_width => {
panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
}
}
name
}
#[derive(Debug, Clone, Copy)]
struct Arg<'ctx> {
ty: BasicMetadataTypeEnum<'ctx>,
val: BasicMetadataValueEnum<'ctx>,
}
/// A structure to construct & call an LLVM function.
///
/// This is a helper to reduce IRRT Inkwell function call boilerplate
// TODO: Remove the lifetimes somehow? There is 4 of them.
pub struct CallFunction<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> {
generator: &'d mut G,
ctx: &'b CodeGenContext<'ctx, 'a>,
/// Function name
name: &'c str,
/// Call arguments
args: Vec<Arg<'ctx>>,
/// LLVM function Attributes
attrs: Vec<&'static str>,
}
impl<'ctx, 'a, 'b, 'c, 'd, G: CodeGenerator + ?Sized> CallFunction<'ctx, 'a, 'b, 'c, 'd, G> {
pub fn begin(generator: &'d mut G, ctx: &'b CodeGenContext<'ctx, 'a>, name: &'c str) -> Self {
CallFunction { generator, ctx, name, args: Vec::new(), attrs: Vec::new() }
}
/// Push a list of LLVM function attributes to the function declaration.
#[must_use]
pub fn attrs(mut self, attrs: Vec<&'static str>) -> Self {
self.attrs = attrs;
self
}
/// Push a call argument to the function call.
#[allow(clippy::needless_pass_by_value)]
#[must_use]
pub fn arg<M: Model<'ctx>>(mut self, arg: Instance<'ctx, M>) -> Self {
let arg = Arg {
ty: arg.model.get_type(self.generator, self.ctx.ctx).as_basic_type_enum().into(),
val: arg.value.as_basic_value_enum().into(),
};
self.args.push(arg);
self
}
/// Call the function and expect the function to return a value of type of `return_model`.
#[must_use]
pub fn returning<M: Model<'ctx>>(self, name: &str, return_model: M) -> Instance<'ctx, M> {
let ret_ty = return_model.get_type(self.generator, self.ctx.ctx);
let ret = self.get_function(|tys| ret_ty.fn_type(tys, false), name);
let ret = BasicValueEnum::try_from(ret.as_any_value_enum()).unwrap(); // Must work
let ret = return_model.check_value(self.generator, self.ctx.ctx, ret).unwrap(); // Must work
ret
}
/// Like [`CallFunction::returning_`] but `return_model` is automatically inferred.
#[must_use]
pub fn returning_auto<M: Model<'ctx> + Default>(self, name: &str) -> Instance<'ctx, M> {
self.returning(name, M::default())
}
/// Call the function and expect the function to return a void-type.
pub fn returning_void(self) {
let ret_ty = self.ctx.ctx.void_type();
let _ = self.get_function(|tys| ret_ty.fn_type(tys, false), "");
}
fn get_function<F>(&self, make_fn_type: F, return_value_name: &str) -> CallSiteValue<'ctx>
where
F: FnOnce(&[BasicMetadataTypeEnum<'ctx>]) -> FunctionType<'ctx>,
{
// Get the LLVM function.
let func = self.ctx.module.get_function(self.name).unwrap_or_else(|| {
// Declare the function if it doesn't exist.
let tys = self.args.iter().map(|arg| arg.ty).collect_vec();
let func_type = make_fn_type(&tys);
let func = self.ctx.module.add_function(self.name, func_type, None);
for attr in &self.attrs {
func.add_attribute(
AttributeLoc::Function,
self.ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
let vals = self.args.iter().map(|arg| arg.val).collect_vec();
self.ctx.builder.build_call(func, &vals, return_value_name).unwrap()
}
}

View File

@ -1,275 +0,0 @@
use std::fmt;
use inkwell::{context::Context, types::IntType, values::IntValue, IntPredicate};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
pub trait IntKind<'ctx>: fmt::Debug + Clone + Copy {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx>;
}
#[derive(Debug, Clone, Copy, Default)]
pub struct Bool;
#[derive(Debug, Clone, Copy, Default)]
pub struct Byte;
#[derive(Debug, Clone, Copy, Default)]
pub struct Int32;
#[derive(Debug, Clone, Copy, Default)]
pub struct Int64;
#[derive(Debug, Clone, Copy, Default)]
pub struct SizeT;
impl<'ctx> IntKind<'ctx> for Bool {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
ctx.bool_type()
}
}
impl<'ctx> IntKind<'ctx> for Byte {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
ctx.i8_type()
}
}
impl<'ctx> IntKind<'ctx> for Int32 {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
ctx.i32_type()
}
}
impl<'ctx> IntKind<'ctx> for Int64 {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
ctx.i64_type()
}
}
impl<'ctx> IntKind<'ctx> for SizeT {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> IntType<'ctx> {
generator.get_size_type(ctx)
}
}
#[derive(Debug, Clone, Copy)]
pub struct AnyInt<'ctx>(pub IntType<'ctx>);
impl<'ctx> IntKind<'ctx> for AnyInt<'ctx> {
fn get_int_type<G: CodeGenerator + ?Sized>(
&self,
_generator: &G,
_ctx: &'ctx Context,
) -> IntType<'ctx> {
self.0
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct IntModel<N>(pub N);
pub type Int<'ctx, N> = Instance<'ctx, IntModel<N>>;
impl<'ctx, N: IntKind<'ctx>> Model<'ctx> for IntModel<N> {
type Value = IntValue<'ctx>;
type Type = IntType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_int_type(generator, ctx)
}
fn check_type<T: inkwell::types::BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = IntType::try_from(ty) else {
return Err(ModelError(format!("Expecting IntType, but got {ty:?}")));
};
let exp_ty = self.0.get_int_type(generator, ctx);
if ty.get_bit_width() != exp_ty.get_bit_width() {
return Err(ModelError(format!(
"Expecting IntType to have {} bit(s), but got {} bit(s)",
exp_ty.get_bit_width(),
ty.get_bit_width()
)));
}
Ok(())
}
}
impl<'ctx, N: IntKind<'ctx>> IntModel<N> {
pub fn constant<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
value: u64,
) -> Int<'ctx, N> {
let value = self.get_type(generator, ctx).const_int(value, false);
self.believe_value(value)
}
pub fn const_0<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Int<'ctx, N> {
self.constant(generator, ctx, 0)
}
pub fn const_1<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Int<'ctx, N> {
self.constant(generator, ctx, 1)
}
pub fn const_all_1s<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Int<'ctx, N> {
let value = self.get_type(generator, ctx).const_all_ones();
self.believe_value(value)
}
pub fn s_extend_or_bit_cast<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
name: &str,
) -> Int<'ctx, N> {
let value = ctx
.builder
.build_int_s_extend_or_bit_cast(value, self.get_type(generator, ctx.ctx), name)
.unwrap();
self.believe_value(value)
}
pub fn truncate<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
name: &str,
) -> Int<'ctx, N> {
let value =
ctx.builder.build_int_truncate(value, self.get_type(generator, ctx.ctx), name).unwrap();
self.believe_value(value)
}
}
impl IntModel<Bool> {
#[must_use]
pub fn const_false<'ctx, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Int<'ctx, Bool> {
self.constant(generator, ctx, 0)
}
#[must_use]
pub fn const_true<'ctx, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Int<'ctx, Bool> {
self.constant(generator, ctx, 1)
}
}
impl<'ctx, N: IntKind<'ctx>> Int<'ctx, N> {
pub fn s_extend_or_bit_cast<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
name: &str,
) -> Int<'ctx, NewN> {
IntModel(to_int_kind).s_extend_or_bit_cast(generator, ctx, self.value, name)
}
pub fn truncate<NewN: IntKind<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
to_int_kind: NewN,
name: &str,
) -> Int<'ctx, NewN> {
IntModel(to_int_kind).truncate(generator, ctx, self.value, name)
}
#[must_use]
pub fn add(
&self,
ctx: &CodeGenContext<'ctx, '_>,
other: Int<'ctx, N>,
name: &str,
) -> Int<'ctx, N> {
let value = ctx.builder.build_int_add(self.value, other.value, name).unwrap();
self.model.believe_value(value)
}
#[must_use]
pub fn sub(
&self,
ctx: &CodeGenContext<'ctx, '_>,
other: Int<'ctx, N>,
name: &str,
) -> Int<'ctx, N> {
let value = ctx.builder.build_int_sub(self.value, other.value, name).unwrap();
self.model.believe_value(value)
}
#[must_use]
pub fn mul(
&self,
ctx: &CodeGenContext<'ctx, '_>,
other: Int<'ctx, N>,
name: &str,
) -> Int<'ctx, N> {
let value = ctx.builder.build_int_mul(self.value, other.value, name).unwrap();
self.model.believe_value(value)
}
pub fn compare(
&self,
ctx: &CodeGenContext<'ctx, '_>,
op: IntPredicate,
other: Int<'ctx, N>,
name: &str,
) -> Int<'ctx, Bool> {
let bool_model = IntModel(Bool);
let value = ctx.builder.build_int_compare(op, self.value, other.value, name).unwrap();
bool_model.believe_value(value)
}
}

View File

@ -1,17 +0,0 @@
mod any;
mod array;
mod core;
mod float;
pub mod function;
mod int;
mod ptr;
mod structure;
pub mod util;
pub use any::*;
pub use array::*;
pub use core::*;
pub use float::*;
pub use int::*;
pub use ptr::*;
pub use structure::*;

View File

@ -1,145 +0,0 @@
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
#[derive(Debug, Clone, Copy, Default)]
pub struct PtrModel<Element>(pub Element);
pub type Ptr<'ctx, Element> = Instance<'ctx, PtrModel<Element>>;
impl<'ctx, Element: Model<'ctx>> Model<'ctx> for PtrModel<Element> {
type Value = PointerValue<'ctx>;
type Type = PointerType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_type(generator, ctx).ptr_type(AddressSpace::default())
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = PointerType::try_from(ty) else {
return Err(ModelError(format!("Expecting PointerType, but got {ty:?}")));
};
let elem_ty = ty.get_element_type();
let Ok(elem_ty) = BasicTypeEnum::try_from(elem_ty) else {
return Err(ModelError(format!(
"Expecting pointer element type to be a BasicTypeEnum, but got {elem_ty:?}"
)));
};
// TODO: inkwell `get_element_type()` will be deprecated.
// Remove the check for `get_element_type()` when the time comes.
self.0
.check_type(generator, ctx, elem_ty)
.map_err(|err| err.under_context("a PointerType"))?;
Ok(())
}
}
impl<'ctx, Element: Model<'ctx>> PtrModel<Element> {
/// Return a ***constant*** nullptr.
pub fn nullptr<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Ptr<'ctx, Element> {
let ptr = self.get_type(generator, ctx).const_null();
self.believe_value(ptr)
}
/// Cast a pointer into this model with [`inkwell::builder::Builder::build_pointer_cast`]
pub fn pointer_cast<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
ptr: PointerValue<'ctx>,
name: &str,
) -> Ptr<'ctx, Element> {
let ptr =
ctx.builder.build_pointer_cast(ptr, self.get_type(generator, ctx.ctx), name).unwrap();
self.believe_value(ptr)
}
}
impl<'ctx, Element: Model<'ctx>> Ptr<'ctx, Element> {
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`].
#[must_use]
pub fn offset<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
offset: IntValue<'ctx>,
name: &str,
) -> Ptr<'ctx, Element> {
let new_ptr =
unsafe { ctx.builder.build_in_bounds_gep(self.value, &[offset], name).unwrap() };
self.model.check_value(generator, ctx.ctx, new_ptr).unwrap()
}
/// Offset the pointer by [`inkwell::builder::Builder::build_in_bounds_gep`] by a constant offset.
#[must_use]
pub fn offset_const<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
offset: u64,
name: &str,
) -> Ptr<'ctx, Element> {
let offset = ctx.ctx.i32_type().const_int(offset, false);
self.offset(generator, ctx, offset, name)
}
/// Load the value with [`inkwell::builder::Builder::build_load`].
pub fn load<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
name: &str,
) -> Instance<'ctx, Element> {
let value = ctx.builder.build_load(self.value, name).unwrap();
self.model.0.check_value(generator, ctx.ctx, value).unwrap() // If unwrap() panics, there is a logic error.
}
/// Store a value with [`inkwell::builder::Builder::build_store`].
pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, value: Instance<'ctx, Element>) {
ctx.builder.build_store(self.value, value.value).unwrap();
}
/// Return a casted pointer of element type `NewElement` with [`inkwell::builder::Builder::build_pointer_cast`].
pub fn pointer_cast<NewElement: Model<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
new_model: NewElement,
name: &str,
) -> Ptr<'ctx, NewElement> {
PtrModel(new_model).pointer_cast(generator, ctx, self.value, name)
}
/// Check if the pointer is null with [`inkwell::builder::Builder::build_is_null`].
pub fn is_null(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Int<'ctx, Bool> {
let bool_model = IntModel(Bool);
let value = ctx.builder.build_is_null(self.value, name).unwrap();
bool_model.believe_value(value)
}
/// Check if the pointer is not null with [`inkwell::builder::Builder::build_is_not_null`].
pub fn is_not_null(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Int<'ctx, Bool> {
let bool_model = IntModel(Bool);
let value = ctx.builder.build_is_not_null(self.value, name).unwrap();
bool_model.believe_value(value)
}
}

View File

@ -1,222 +0,0 @@
use std::fmt;
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, StructType},
values::StructValue,
};
use crate::codegen::{CodeGenContext, CodeGenerator};
use super::*;
#[derive(Debug, Clone, Copy)]
pub struct GepField<M> {
pub gep_index: u64,
pub name: &'static str,
pub model: M,
}
pub trait FieldTraversal<'ctx> {
type Out<M>;
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M>;
/// Like [`FieldTraversal::visit`] but [`Model`] is automatically inferred from [`Default`] trait.
fn add_auto<M: Model<'ctx> + Default>(&mut self, name: &'static str) -> Self::Out<M> {
self.add(name, M::default())
}
}
pub struct GepFieldTraversal {
gep_index_counter: u64,
}
impl<'ctx> FieldTraversal<'ctx> for GepFieldTraversal {
type Out<M> = GepField<M>;
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M> {
let gep_index = self.gep_index_counter;
self.gep_index_counter += 1;
Self::Out { gep_index, name, model }
}
}
struct TypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
generator: &'a G,
ctx: &'ctx Context,
field_types: Vec<BasicTypeEnum<'ctx>>,
}
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx> for TypeFieldTraversal<'ctx, 'a, G> {
type Out<M> = ();
fn add<M: Model<'ctx>>(&mut self, _name: &'static str, model: M) -> Self::Out<M> {
let t = model.get_type(self.generator, self.ctx).as_basic_type_enum();
self.field_types.push(t);
}
}
struct CheckTypeFieldTraversal<'ctx, 'a, G: CodeGenerator + ?Sized> {
generator: &'a mut G,
ctx: &'ctx Context,
index: u32,
scrutinee: StructType<'ctx>,
errors: Vec<ModelError>,
}
impl<'ctx, 'a, G: CodeGenerator + ?Sized> FieldTraversal<'ctx>
for CheckTypeFieldTraversal<'ctx, 'a, G>
{
type Out<M> = ();
fn add<M: Model<'ctx>>(&mut self, name: &'static str, model: M) -> Self::Out<M> {
let i = self.index;
self.index += 1;
if let Some(t) = self.scrutinee.get_field_type_at_index(i) {
if let Err(err) = model.check_type(self.generator, self.ctx, t) {
self.errors.push(err.under_context(format!("field #{i} '{name}'").as_str()));
}
} // Otherwise, it will be caught
}
}
pub trait StructKind<'ctx>: fmt::Debug + Clone + Copy {
type Fields<F: FieldTraversal<'ctx>>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F>;
fn fields(&self) -> Self::Fields<GepFieldTraversal> {
self.traverse_fields(&mut GepFieldTraversal { gep_index_counter: 0 })
}
fn get_struct_type<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &'ctx Context,
) -> StructType<'ctx> {
let mut traversal = TypeFieldTraversal { generator, ctx, field_types: Vec::new() };
self.traverse_fields(&mut traversal);
ctx.struct_type(&traversal.field_types, false)
}
}
#[derive(Debug, Clone, Copy, Default)]
pub struct StructModel<S>(pub S);
pub type Struct<'ctx, S> = Instance<'ctx, StructModel<S>>;
impl<'ctx, S: StructKind<'ctx>> Model<'ctx> for StructModel<S> {
type Value = StructValue<'ctx>;
type Type = StructType<'ctx>;
fn get_type<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &'ctx Context) -> Self::Type {
self.0.get_struct_type(generator, ctx)
}
fn check_type<T: BasicType<'ctx>, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
ty: T,
) -> Result<(), ModelError> {
let ty = ty.as_basic_type_enum();
let Ok(ty) = StructType::try_from(ty) else {
return Err(ModelError(format!("Expecting StructType, but got {ty:?}")));
};
let mut traversal =
CheckTypeFieldTraversal { generator, ctx, index: 0, errors: Vec::new(), scrutinee: ty };
self.0.traverse_fields(&mut traversal);
let exp_num_fields = traversal.index;
let got_num_fields = u32::try_from(ty.get_field_types().len()).unwrap();
if exp_num_fields != got_num_fields {
return Err(ModelError(format!(
"Expecting StructType with {exp_num_fields} field(s), but got {got_num_fields}"
)));
}
if !traversal.errors.is_empty() {
return Err(traversal.errors[0].clone()); // TODO: Return other errors as well
}
Ok(())
}
}
impl<'ctx, S: StructKind<'ctx>> Struct<'ctx, S> {
pub fn get_field<G: CodeGenerator + ?Sized, M, GetField>(
&self,
generator: &mut G,
ctx: &'ctx Context,
get_field: GetField,
) -> Instance<'ctx, M>
where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
let field = get_field(self.model.0.fields());
let val = self.value.get_field_at_index(field.gep_index as u32).unwrap();
field.model.check_value(generator, ctx, val).unwrap()
}
}
impl<'ctx, S: StructKind<'ctx>> Ptr<'ctx, StructModel<S>> {
pub fn gep<M, GetField>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
get_field: GetField,
) -> Ptr<'ctx, M>
where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
let field = get_field(self.model.0 .0.fields());
let llvm_i32 = ctx.ctx.i32_type(); // i64 would segfault
let ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
self.value,
&[llvm_i32.const_zero(), llvm_i32.const_int(field.gep_index, false)],
field.name,
)
.unwrap()
};
let ptr_model = PtrModel(field.model);
ptr_model.believe_value(ptr)
}
/// Convenience function equivalent to `.gep(...).load(...)`.
pub fn get<M, GetField, G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
get_field: GetField,
name: &str,
) -> Instance<'ctx, M>
where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
self.gep(ctx, get_field).load(generator, ctx, name)
}
/// Convenience function equivalent to `.gep(...).store(...)`.
pub fn set<M, GetField>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
get_field: GetField,
value: Instance<'ctx, M>,
) where
M: Model<'ctx>,
GetField: FnOnce(S::Fields<GepFieldTraversal>) -> GepField<M>,
{
self.gep(ctx, get_field).store(ctx, value);
}
}
// TODO: Add an opaque struct type?

View File

@ -1,62 +0,0 @@
use inkwell::{types::BasicType, values::IntValue};
/// `llvm.memcpy` but under the [`Model`] abstraction
use crate::codegen::{
llvm_intrinsics::call_memcpy_generic,
stmt::{gen_for_callback_incrementing, BreakContinueHooks},
CodeGenContext, CodeGenerator,
};
use super::*;
/// Convenience function.
///
/// Like [`call_memcpy_generic`] but with model abstractions and `is_volatile` set to `false`.
pub fn call_memcpy_model<'ctx, Item: Model<'ctx> + Default, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
dst_array: Ptr<'ctx, Item>,
src_array: Ptr<'ctx, Item>,
num_items: IntValue<'ctx>,
) {
let itemsize = Item::default().get_type(generator, ctx.ctx).size_of().unwrap();
let totalsize = ctx.builder.build_int_mul(itemsize, num_items, "totalsize").unwrap(); // TODO: Int types may not match.
let is_volatile = ctx.ctx.bool_type().const_zero();
call_memcpy_generic(ctx, dst_array.value, src_array.value, totalsize, is_volatile);
}
/// Like [`gen_for_callback_incrementing`] with [`Model`] abstractions.
/// The [`IntKind`] is automatically inferred.
pub fn gen_for_model_auto<'ctx, 'a, G, F, I>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
start: Int<'ctx, I>,
stop: Int<'ctx, I>,
step: Int<'ctx, I>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
Int<'ctx, I>,
) -> Result<(), String>,
I: IntKind<'ctx> + Default,
{
let int_model = IntModel(I::default());
gen_for_callback_incrementing(
generator,
ctx,
None,
start.value,
(stop.value, false),
|g, ctx, hooks, i| {
let i = int_model.believe_value(i);
body(g, ctx, hooks, i)
},
step.value,
)
}

View File

@ -1,23 +1,32 @@
use crate::{ use inkwell::{
codegen::{ types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType},
classes::{ values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayType, NDArrayValue, AddressSpace, IntPredicate, OptimizationLevel,
ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, };
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
}, use nac3parser::ast::{Operator, StrRef};
expr::gen_binop_expr_with_values,
irrt::{ use super::{
calculate_len_for_slice_range, call_ndarray_calc_broadcast, expr::gen_binop_expr_with_values,
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, irrt::{
call_ndarray_calc_size, calculate_len_for_slice_range, call_ndarray_calc_broadcast,
}, call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size,
llvm_intrinsics::{self, call_memcpy_generic},
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
CodeGenContext, CodeGenerator,
}, },
llvm_intrinsics::{self, call_memcpy_generic},
macros::codegen_unreachable,
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
types::{ListType, NDArrayType, ProxyType},
values::{
ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
},
CodeGenContext, CodeGenerator,
};
use crate::{
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{ toplevel::{
helper::PrimDef, helper::{arraylike_flatten_element_type, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
DefinitionId, DefinitionId,
}, },
@ -26,16 +35,6 @@ use crate::{
typedef::{FunSignature, Type, TypeEnum}, typedef::{FunSignature, Type, TypeEnum},
}, },
}; };
use inkwell::{
types::BasicType,
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel,
};
use inkwell::{
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
values::BasicValue,
};
use nac3parser::ast::{Operator, StrRef};
/// Creates an uninitialized `NDArray` instance. /// Creates an uninitialized `NDArray` instance.
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
@ -43,6 +42,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
) -> Result<NDArrayValue<'ctx>, String> { ) -> Result<NDArrayValue<'ctx>, String> {
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
@ -55,7 +55,7 @@ fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
Ok(NDArrayValue::from_ptr_val(ndarray, llvm_usize, None)) Ok(NDArrayValue::from_pointer_value(ndarray, llvm_elem_ty, llvm_usize, None))
} }
/// Creates an `NDArray` instance from a dynamic shape. /// Creates an `NDArray` instance from a dynamic shape.
@ -128,7 +128,7 @@ where
ndarray.store_ndims(ctx, generator, num_dims); ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
// Copy the dimension sizes from shape to ndarray.dims // Copy the dimension sizes from shape to ndarray.dims
let shape_len = shape_len_fn(generator, ctx, shape)?; let shape_len = shape_len_fn(generator, ctx, shape)?;
@ -144,7 +144,7 @@ where
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let ndarray_pdim = let ndarray_pdim =
unsafe { ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, &i, None) }; unsafe { ndarray.shape().ptr_offset_unchecked(ctx, generator, &i, None) };
ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap();
@ -195,12 +195,12 @@ pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
ndarray.store_ndims(ctx, generator, num_dims); ndarray.store_ndims(ctx, generator, num_dims);
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); ndarray.create_shape(ctx, llvm_usize, ndarray_num_dims);
for (i, &shape_dim) in shape.iter().enumerate() { for (i, &shape_dim) in shape.iter().enumerate() {
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let ndarray_dim = unsafe { let ndarray_dim = unsafe {
ndarray.dim_sizes().ptr_offset_unchecked( ndarray.shape().ptr_offset_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(i as u64, true), &llvm_usize.const_int(i as u64, true),
@ -229,7 +229,7 @@ fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_num_elems = call_ndarray_calc_size( let ndarray_num_elems = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator), &ndarray.shape().as_slice_value(ctx, generator),
(None, None), (None, None),
); );
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
@ -257,9 +257,9 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into() ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "").value.into() ctx.gen_string(generator, "").into()
} else { } else {
unreachable!() codegen_unreachable!(ctx)
} }
} }
@ -285,9 +285,9 @@ fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into() ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
ctx.gen_string(generator, "1").value.into() ctx.gen_string(generator, "1").into()
} else { } else {
unreachable!() codegen_unreachable!(ctx)
} }
} }
@ -315,11 +315,11 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
match shape { match shape {
BasicValueEnum::PointerValue(shape_list_ptr) BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() => if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() =>
{ {
// 1. A list of ints; e.g., `np.empty([600, 800, 3])` // 1. A list of ints; e.g., `np.empty([600, 800, 3])`
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None); let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None);
create_ndarray_dyn_shape( create_ndarray_dyn_shape(
generator, generator,
ctx, ctx,
@ -355,7 +355,7 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
} }
_ => unreachable!(), _ => codegen_unreachable!(ctx),
} }
} }
@ -380,7 +380,7 @@ where
let ndarray_num_elems = call_ndarray_calc_size( let ndarray_num_elems = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator), &ndarray.shape().as_slice_value(ctx, generator),
(None, None), (None, None),
); );
@ -474,8 +474,8 @@ fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
res: NDArrayValue<'ctx>, res: NDArrayValue<'ctx>,
lhs: (BasicValueEnum<'ctx>, bool), lhs: (Type, BasicValueEnum<'ctx>, bool),
rhs: (BasicValueEnum<'ctx>, bool), rhs: (Type, BasicValueEnum<'ctx>, bool),
value_fn: ValueFn, value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String> ) -> Result<NDArrayValue<'ctx>, String>
where where
@ -488,8 +488,8 @@ where
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (lhs_val, lhs_scalar) = lhs; let (lhs_ty, lhs_val, lhs_scalar) = lhs;
let (rhs_val, rhs_scalar) = rhs; let (rhs_ty, rhs_val, rhs_scalar) = rhs;
assert!( assert!(
!(lhs_scalar && rhs_scalar), !(lhs_scalar && rhs_scalar),
@ -500,12 +500,26 @@ where
// Assert that all ndarray operands are broadcastable to the target size // Assert that all ndarray operands are broadcastable to the target size
if !lhs_scalar { if !lhs_scalar {
let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty);
let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype);
let lhs_val = NDArrayValue::from_pointer_value(
lhs_val.into_pointer_value(),
llvm_lhs_elem_ty,
llvm_usize,
None,
);
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val); ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
} }
if !rhs_scalar { if !rhs_scalar {
let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty);
let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype);
let rhs_val = NDArrayValue::from_pointer_value(
rhs_val.into_pointer_value(),
llvm_rhs_elem_ty,
llvm_usize,
None,
);
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
} }
@ -513,7 +527,14 @@ where
let lhs_elem = if lhs_scalar { let lhs_elem = if lhs_scalar {
lhs_val lhs_val
} else { } else {
let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty);
let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype);
let lhs = NDArrayValue::from_pointer_value(
lhs_val.into_pointer_value(),
llvm_lhs_elem_ty,
llvm_usize,
None,
);
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) }
@ -522,7 +543,14 @@ where
let rhs_elem = if rhs_scalar { let rhs_elem = if rhs_scalar {
rhs_val rhs_val
} else { } else {
let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty);
let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype);
let rhs = NDArrayValue::from_pointer_value(
rhs_val.into_pointer_value(),
llvm_rhs_elem_ty,
llvm_usize,
None,
);
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) }
@ -626,7 +654,7 @@ fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
} else if fill_value.is_int_value() || fill_value.is_float_value() { } else if fill_value.is_int_value() || fill_value.is_float_value() {
fill_value fill_value
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
Ok(value) Ok(value)
@ -648,11 +676,15 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
let ndims = llvm_usize.const_int(1, false); let ndims = llvm_usize.const_int(1, false);
match list_elem_ty { match list_elem_ty {
AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => { AnyTypeEnum::PointerType(ptr_ty)
if ListType::is_representable(ptr_ty, llvm_usize).is_ok() =>
{
ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty)) ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty))
} }
AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => { AnyTypeEnum::PointerType(ptr_ty)
if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() =>
{
todo!("Getting ndims for list[ndarray] not supported") todo!("Getting ndims for list[ndarray] not supported")
} }
@ -664,16 +696,20 @@ fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>, (ty, value): (Type, BasicValueEnum<'ctx>),
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
match value { match value {
BasicValueEnum::PointerValue(v) if NDArrayValue::is_instance(v, llvm_usize).is_ok() => { BasicValueEnum::PointerValue(v)
NDArrayValue::from_ptr_val(v, llvm_usize, None).load_ndims(ctx) if NDArrayValue::is_representable(v, llvm_usize).is_ok() =>
{
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
NDArrayValue::from_pointer_value(v, llvm_elem_ty, llvm_usize, None).load_ndims(ctx)
} }
BasicValueEnum::PointerValue(v) if ListValue::is_instance(v, llvm_usize).is_ok() => { BasicValueEnum::PointerValue(v) if ListValue::is_representable(v, llvm_usize).is_ok() => {
llvm_ndlist_get_ndims(generator, ctx, v.get_type()) llvm_ndlist_get_ndims(generator, ctx, v.get_type())
} }
@ -685,7 +721,6 @@ fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>(
fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
src_lst: ListValue<'ctx>, src_lst: ListValue<'ctx>,
dim: u64, dim: u64,
@ -696,13 +731,15 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
let list_elem_ty = src_lst.get_type().element_type(); let list_elem_ty = src_lst.get_type().element_type();
match list_elem_ty { match list_elem_ty {
AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => { AnyTypeEnum::PointerType(ptr_ty)
if ListType::is_representable(ptr_ty, llvm_usize).is_ok() =>
{
// The stride of elements in this dimension, i.e. the number of elements between arr[i] // The stride of elements in this dimension, i.e. the number of elements between arr[i]
// and arr[i + 1] in this dimension // and arr[i + 1] in this dimension
let stride = call_ndarray_calc_size( let stride = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&dst_arr.dim_sizes(), &dst_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None), (Some(llvm_usize.const_int(dim + 1, false)), None),
); );
@ -716,11 +753,25 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
|_, _| Ok(llvm_usize.const_int(1, false)), |_, _| Ok(llvm_usize.const_int(1, false)),
|generator, ctx, _, i| { |generator, ctx, _, i| {
let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); let offset = ctx.builder.build_int_mul(stride, i, "").unwrap();
let offset = ctx
.builder
.build_int_mul(
offset,
ctx.builder
.build_int_truncate_or_bit_cast(
dst_arr.get_type().element_type().size_of().unwrap(),
offset.get_type(),
"",
)
.unwrap(),
"",
)
.unwrap();
let dst_ptr = let dst_ptr =
unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() }; unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() };
let nested_lst_elem = ListValue::from_ptr_val( let nested_lst_elem = ListValue::from_pointer_value(
unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) } unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) }
.into_pointer_value(), .into_pointer_value(),
llvm_usize, llvm_usize,
@ -730,7 +781,6 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
ndarray_from_ndlist_impl( ndarray_from_ndlist_impl(
generator, generator,
ctx, ctx,
elem_ty,
(dst_arr, dst_ptr), (dst_arr, dst_ptr),
nested_lst_elem, nested_lst_elem,
dim + 1, dim + 1,
@ -741,13 +791,15 @@ fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>(
)?; )?;
} }
AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => { AnyTypeEnum::PointerType(ptr_ty)
if NDArrayType::is_representable(ptr_ty, llvm_usize).is_ok() =>
{
todo!("Not implemented for list[ndarray]") todo!("Not implemented for list[ndarray]")
} }
_ => { _ => {
let lst_len = src_lst.load_size(ctx, None); let lst_len = src_lst.load_size(ctx, None);
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap();
let sizeof_elem = ctx.builder.build_int_cast(sizeof_elem, llvm_usize, "").unwrap(); let sizeof_elem = ctx.builder.build_int_cast(sizeof_elem, llvm_usize, "").unwrap();
let cpy_len = ctx let cpy_len = ctx
@ -802,8 +854,9 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
let object = object.into_pointer_value(); let object = object.into_pointer_value();
// object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims
if NDArrayValue::is_instance(object, llvm_usize).is_ok() { if NDArrayValue::is_representable(object, llvm_usize).is_ok() {
let object = NDArrayValue::from_ptr_val(object, llvm_usize, None); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let object = NDArrayValue::from_pointer_value(object, llvm_elem_ty, llvm_usize, None);
let ndarray = gen_if_else_expr_callback( let ndarray = gen_if_else_expr_callback(
generator, generator,
@ -865,7 +918,6 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
ndarray_sliced_copyto_impl( ndarray_sliced_copyto_impl(
generator, generator,
ctx, ctx,
elem_ty,
(ndarray, ndarray.data().base_ptr(ctx, generator)), (ndarray, ndarray.data().base_ptr(ctx, generator)),
(object, object.data().base_ptr(ctx, generator)), (object, object.data().base_ptr(ctx, generator)),
0, 0,
@ -877,16 +929,17 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
|_, _| Ok(Some(object.as_base_value())), |_, _| Ok(Some(object.as_base_value())),
)?; )?;
return Ok(NDArrayValue::from_ptr_val( return Ok(NDArrayValue::from_pointer_value(
ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), ndarray.map(BasicValueEnum::into_pointer_value).unwrap(),
llvm_elem_ty,
llvm_usize, llvm_usize,
None, None,
)); ));
} }
// Remaining case: TList // Remaining case: TList
assert!(ListValue::is_instance(object, llvm_usize).is_ok()); assert!(ListValue::is_representable(object, llvm_usize).is_ok());
let object = ListValue::from_ptr_val(object, llvm_usize, None); let object = ListValue::from_pointer_value(object, llvm_usize, None);
// The number of dimensions to prepend 1's to // The number of dimensions to prepend 1's to
let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type());
@ -941,7 +994,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
.build_store( .build_store(
lst, lst,
ctx.builder ctx.builder
.build_bitcast(object.as_base_value(), llvm_plist_i8, "") .build_bit_cast(object.as_base_value(), llvm_plist_i8, "")
.unwrap(), .unwrap(),
) )
.unwrap(); .unwrap();
@ -963,10 +1016,11 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
.builder .builder
.build_load(lst, "") .build_load(lst, "")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.map(|v| ctx.builder.build_bitcast(v, plist_plist_i8, "").unwrap()) .map(|v| ctx.builder.build_bit_cast(v, plist_plist_i8, "").unwrap())
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
let this_dim = ListValue::from_ptr_val(this_dim, llvm_usize, None); let this_dim =
ListValue::from_pointer_value(this_dim, llvm_usize, None);
// TODO: Assert this_dim.sz != 0 // TODO: Assert this_dim.sz != 0
@ -982,7 +1036,9 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder ctx.builder
.build_store( .build_store(
lst, lst,
ctx.builder.build_bitcast(next_dim, llvm_plist_i8, "").unwrap(), ctx.builder
.build_bit_cast(next_dim, llvm_plist_i8, "")
.unwrap(),
) )
.unwrap(); .unwrap();
@ -990,7 +1046,7 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
}, },
)?; )?;
let lst = ListValue::from_ptr_val( let lst = ListValue::from_pointer_value(
ctx.builder ctx.builder
.build_load(lst, "") .build_load(lst, "")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
@ -1010,7 +1066,6 @@ fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>(
ndarray_from_ndlist_impl( ndarray_from_ndlist_impl(
generator, generator,
ctx, ctx,
elem_ty,
(ndarray, ndarray.data().base_ptr(ctx, generator)), (ndarray, ndarray.data().base_ptr(ctx, generator)),
object, object,
0, 0,
@ -1071,19 +1126,18 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
/// Copies a slice of an [`NDArrayValue`] to another. /// Copies a slice of an [`NDArrayValue`] to another.
/// ///
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `dim_sz` /// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `dim_sz`
/// fields should be populated before calling this function. /// fields should be populated before calling this function.
/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing /// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
/// dimensional slice in the destination array. /// dimensional slice in the destination array.
/// - `src_arr`: The [`NDArrayValue`] instance of the source array. /// - `src_arr`: The [`NDArrayValue`] instance of the source array.
/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing /// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
/// dimensional slice in the source array. /// dimensional slice in the source array.
/// - `dim`: The index of the currently processing dimension. /// - `dim`: The index of the currently processing dimension.
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to /// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
/// this dimension. The `start`/`stop` values of each slice must be non-negative indices. /// this dimension. The `start`/`stop` values of each slice must be non-negative indices.
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
(dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
(src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), (src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>),
dim: u64, dim: u64,
@ -1092,14 +1146,16 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
assert_eq!(dst_arr.get_type().element_type(), src_arr.get_type().element_type());
let sizeof_elem = dst_arr.get_type().element_type().size_of().unwrap();
// If there are no (remaining) slice expressions, memcpy the entire dimension // If there are no (remaining) slice expressions, memcpy the entire dimension
if slices.is_empty() { if slices.is_empty() {
let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap();
let stride = call_ndarray_calc_size( let stride = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&src_arr.dim_sizes(), &src_arr.shape(),
(Some(llvm_usize.const_int(dim, false)), None), (Some(llvm_usize.const_int(dim, false)), None),
); );
let stride = let stride =
@ -1117,13 +1173,13 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
let src_stride = call_ndarray_calc_size( let src_stride = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&src_arr.dim_sizes(), &src_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None), (Some(llvm_usize.const_int(dim + 1, false)), None),
); );
let dst_stride = call_ndarray_calc_size( let dst_stride = call_ndarray_calc_size(
generator, generator,
ctx, ctx,
&dst_arr.dim_sizes(), &dst_arr.shape(),
(Some(llvm_usize.const_int(dim + 1, false)), None), (Some(llvm_usize.const_int(dim + 1, false)), None),
); );
@ -1146,9 +1202,29 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, _, src_i| { |generator, ctx, _, src_i| {
// Calculate the offset of the active slice // Calculate the offset of the active slice
let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap();
let src_data_offset = ctx
.builder
.build_int_mul(
src_data_offset,
ctx.builder
.build_int_cast(sizeof_elem, src_data_offset.get_type(), "")
.unwrap(),
"",
)
.unwrap();
let dst_i = let dst_i =
ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap(); let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap();
let dst_data_offset = ctx
.builder
.build_int_mul(
dst_data_offset,
ctx.builder
.build_int_cast(sizeof_elem, dst_data_offset.get_type(), "")
.unwrap(),
"",
)
.unwrap();
let (src_ptr, dst_ptr) = unsafe { let (src_ptr, dst_ptr) = unsafe {
( (
@ -1160,7 +1236,6 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
ndarray_sliced_copyto_impl( ndarray_sliced_copyto_impl(
generator, generator,
ctx, ctx,
elem_ty,
(dst_arr, dst_ptr), (dst_arr, dst_ptr),
(src_arr, src_ptr), (src_arr, src_ptr),
dim + 1, dim + 1,
@ -1184,7 +1259,7 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to /// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
/// this dimension. The `start`/`stop` values of each slice must be positive indices. /// this dimension. The `start`/`stop` values of each slice must be positive indices.
pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -1203,7 +1278,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
&this, &this,
|_, ctx, shape| Ok(shape.load_ndims(ctx)), |_, ctx, shape| Ok(shape.load_ndims(ctx)),
|generator, ctx, shape, idx| unsafe { |generator, ctx, shape, idx| unsafe {
Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) Ok(shape.shape().get_typed_unchecked(ctx, generator, &idx, None))
}, },
)? )?
} else { } else {
@ -1211,7 +1286,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
ndarray.store_ndims(ctx, generator, this.load_ndims(ctx)); ndarray.store_ndims(ctx, generator, this.load_ndims(ctx));
let ndims = this.load_ndims(ctx); let ndims = this.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndims); ndarray.create_shape(ctx, llvm_usize, ndims);
// Populate the first slices.len() dimensions by computing the size of each dim slice // Populate the first slices.len() dimensions by computing the size of each dim slice
for (i, (start, stop, step)) in slices.iter().enumerate() { for (i, (start, stop, step)) in slices.iter().enumerate() {
@ -1243,7 +1318,7 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap(); ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap();
unsafe { unsafe {
ndarray.dim_sizes().set_typed_unchecked( ndarray.shape().set_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(i as u64, false), &llvm_usize.const_int(i as u64, false),
@ -1261,8 +1336,8 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
(this.load_ndims(ctx), false), (this.load_ndims(ctx), false),
|generator, ctx, _, idx| { |generator, ctx, _, idx| {
unsafe { unsafe {
let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None); let dim_sz = this.shape().get_typed_unchecked(ctx, generator, &idx, None);
ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz); ndarray.shape().set_typed_unchecked(ctx, generator, &idx, dim_sz);
} }
Ok(()) Ok(())
@ -1277,7 +1352,6 @@ pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
ndarray_sliced_copyto_impl( ndarray_sliced_copyto_impl(
generator, generator,
ctx, ctx,
elem_ty,
(ndarray, ndarray.data().base_ptr(ctx, generator)), (ndarray, ndarray.data().base_ptr(ctx, generator)),
(this, this.data().base_ptr(ctx, generator)), (this, this.data().base_ptr(ctx, generator)),
0, 0,
@ -1323,7 +1397,7 @@ where
&operand, &operand,
|_, ctx, v| Ok(v.load_ndims(ctx)), |_, ctx, v| Ok(v.load_ndims(ctx)),
|generator, ctx, v, idx| unsafe { |generator, ctx, v, idx| unsafe {
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
}, },
) )
.unwrap() .unwrap()
@ -1349,7 +1423,7 @@ where
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be /// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
/// written to a new `ndarray`. /// written to a new `ndarray`.
/// * `value_fn` - Function mapping the two input elements into the result. /// * `value_fn` - Function mapping the two input elements into the result.
/// ///
/// # Panic /// # Panic
@ -1360,8 +1434,8 @@ pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>(
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
elem_ty: Type, elem_ty: Type,
res: Option<NDArrayValue<'ctx>>, res: Option<NDArrayValue<'ctx>>,
lhs: (BasicValueEnum<'ctx>, bool), lhs: (Type, BasicValueEnum<'ctx>, bool),
rhs: (BasicValueEnum<'ctx>, bool), rhs: (Type, BasicValueEnum<'ctx>, bool),
value_fn: ValueFn, value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String> ) -> Result<NDArrayValue<'ctx>, String>
where where
@ -1374,8 +1448,8 @@ where
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (lhs_val, lhs_scalar) = lhs; let (lhs_ty, lhs_val, lhs_scalar) = lhs;
let (rhs_val, rhs_scalar) = rhs; let (rhs_ty, rhs_val, rhs_scalar) = rhs;
assert!( assert!(
!(lhs_scalar && rhs_scalar), !(lhs_scalar && rhs_scalar),
@ -1386,10 +1460,22 @@ where
let ndarray = res.unwrap_or_else(|| { let ndarray = res.unwrap_or_else(|| {
if lhs_scalar && rhs_scalar { if lhs_scalar && rhs_scalar {
let lhs_val = let lhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, lhs_ty);
NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); let llvm_lhs_elem_ty = ctx.get_llvm_type(generator, lhs_dtype);
let rhs_val = let lhs_val = NDArrayValue::from_pointer_value(
NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); lhs_val.into_pointer_value(),
llvm_lhs_elem_ty,
llvm_usize,
None,
);
let rhs_dtype = arraylike_flatten_element_type(&mut ctx.unifier, rhs_ty);
let llvm_rhs_elem_ty = ctx.get_llvm_type(generator, rhs_dtype);
let rhs_val = NDArrayValue::from_pointer_value(
rhs_val.into_pointer_value(),
llvm_rhs_elem_ty,
llvm_usize,
None,
);
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
@ -1405,8 +1491,14 @@ where
) )
.unwrap() .unwrap()
} else { } else {
let ndarray = NDArrayValue::from_ptr_val( let dtype = arraylike_flatten_element_type(
&mut ctx.unifier,
if lhs_scalar { rhs_ty } else { lhs_ty },
);
let llvm_elem_ty = ctx.get_llvm_type(generator, dtype);
let ndarray = NDArrayValue::from_pointer_value(
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
llvm_elem_ty,
llvm_usize, llvm_usize,
None, None,
); );
@ -1418,7 +1510,7 @@ where
&ndarray, &ndarray,
|_, ctx, v| Ok(v.load_ndims(ctx)), |_, ctx, v| Ok(v.load_ndims(ctx)),
|generator, ctx, v, idx| unsafe { |generator, ctx, v, idx| unsafe {
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) Ok(v.shape().get_typed_unchecked(ctx, generator, &idx, None))
}, },
) )
.unwrap() .unwrap()
@ -1436,7 +1528,7 @@ where
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be /// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
/// written to a new `ndarray`. /// written to a new `ndarray`.
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -1479,10 +1571,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
if let Some(res) = res { if let Some(res) = res {
let res_ndims = res.load_ndims(ctx); let res_ndims = res.load_ndims(ctx);
let res_dim0 = unsafe { let res_dim0 = unsafe {
res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) res.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
}; };
let res_dim1 = unsafe { let res_dim1 = unsafe {
res.dim_sizes().get_typed_unchecked( res.shape().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(1, false), &llvm_usize.const_int(1, false),
@ -1490,10 +1582,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
) )
}; };
let lhs_dim0 = unsafe { let lhs_dim0 = unsafe {
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
}; };
let rhs_dim1 = unsafe { let rhs_dim1 = unsafe {
rhs.dim_sizes().get_typed_unchecked( rhs.shape().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(1, false), &llvm_usize.const_int(1, false),
@ -1542,15 +1634,10 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let lhs_dim1 = unsafe { let lhs_dim1 = unsafe {
lhs.dim_sizes().get_typed_unchecked( lhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
}; };
let rhs_dim0 = unsafe { let rhs_dim0 = unsafe {
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
}; };
// lhs.dims[1] == rhs.dims[0] // lhs.dims[1] == rhs.dims[0]
@ -1589,7 +1676,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
}, },
|generator, ctx| { |generator, ctx| {
Ok(Some(unsafe { Ok(Some(unsafe {
lhs.dim_sizes().get_typed_unchecked( lhs.shape().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_zero(), &llvm_usize.const_zero(),
@ -1599,7 +1686,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
}, },
|generator, ctx| { |generator, ctx| {
Ok(Some(unsafe { Ok(Some(unsafe {
rhs.dim_sizes().get_typed_unchecked( rhs.shape().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(1, false), &llvm_usize.const_int(1, false),
@ -1626,7 +1713,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
let common_dim = { let common_dim = {
let lhs_idx1 = unsafe { let lhs_idx1 = unsafe {
lhs.dim_sizes().get_typed_unchecked( lhs.shape().get_typed_unchecked(
ctx, ctx,
generator, generator,
&llvm_usize.const_int(1, false), &llvm_usize.const_int(1, false),
@ -1634,7 +1721,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
) )
}; };
let rhs_idx0 = unsafe { let rhs_idx0 = unsafe {
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) rhs.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
}; };
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
@ -1965,11 +2052,18 @@ pub fn gen_ndarray_copy<'ctx>(
let this_arg = let this_arg =
obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?;
let llvm_elem_ty = context.get_llvm_type(generator, this_elem_ty);
ndarray_copy_impl( ndarray_copy_impl(
generator, generator,
context, context,
this_elem_ty, this_elem_ty,
NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None), NDArrayValue::from_pointer_value(
this_arg.into_pointer_value(),
llvm_elem_ty,
llvm_usize,
None,
),
) )
.map(NDArrayValue::into) .map(NDArrayValue::into)
} }
@ -1988,6 +2082,7 @@ pub fn gen_ndarray_fill<'ctx>(
let llvm_usize = generator.get_size_type(context.ctx); let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0; let this_ty = obj.as_ref().unwrap().0;
let this_elem_ty = arraylike_flatten_element_type(&mut context.unifier, this_ty);
let this_arg = obj let this_arg = obj
.as_ref() .as_ref()
.unwrap() .unwrap()
@ -1998,10 +2093,12 @@ pub fn gen_ndarray_fill<'ctx>(
let value_ty = fun.0.args[0].ty; let value_ty = fun.0.args[0].ty;
let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?;
let llvm_elem_ty = context.get_llvm_type(generator, this_elem_ty);
ndarray_fill_flattened( ndarray_fill_flattened(
generator, generator,
context, context,
NDArrayValue::from_ptr_val(this_arg, llvm_usize, None), NDArrayValue::from_pointer_value(this_arg, llvm_elem_ty, llvm_usize, None),
|generator, ctx, _| { |generator, ctx, _| {
let value = if value_arg.is_pointer_value() { let value = if value_arg.is_pointer_value() {
let llvm_i1 = ctx.ctx.bool_type(); let llvm_i1 = ctx.ctx.bool_type();
@ -2020,7 +2117,7 @@ pub fn gen_ndarray_fill<'ctx>(
} else if value_arg.is_int_value() || value_arg.is_float_value() { } else if value_arg.is_int_value() || value_arg.is_float_value() {
value_arg value_arg
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
Ok(value) Ok(value)
@ -2042,8 +2139,9 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
if let BasicValueEnum::PointerValue(n1) = x1 { if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
// Dimensions are reversed in the transposed array // Dimensions are reversed in the transposed array
let out = create_ndarray_dyn_shape( let out = create_ndarray_dyn_shape(
@ -2058,7 +2156,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
.builder .builder
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "") .build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
.unwrap(); .unwrap();
unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) } unsafe { Ok(n.shape().get_typed_unchecked(ctx, generator, &new_idx, None)) }
}, },
) )
.unwrap(); .unwrap();
@ -2095,7 +2193,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "") .build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
.unwrap(); .unwrap();
let dim = unsafe { let dim = unsafe {
n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None) n1.shape().get_typed_unchecked(ctx, generator, &ndim_rev, None)
}; };
let rem_idx_val = let rem_idx_val =
@ -2129,7 +2227,8 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
Ok(out.as_base_value().into()) Ok(out.as_base_value().into())
} else { } else {
unreachable!( codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'", "{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty)) format!("'{}'", ctx.unifier.stringify(x1_ty))
) )
@ -2140,11 +2239,12 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
/// ///
/// * `x1` - `NDArray` to reshape. /// * `x1` - `NDArray` to reshape.
/// * `shape` - The `shape` parameter used to construct the new `NDArray`. /// * `shape` - The `shape` parameter used to construct the new `NDArray`.
/// Just like numpy, the `shape` argument can be: /// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])` /// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` /// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` /// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
/// Note that unlike other generating functions, one of the dimesions in the shape can be negative ///
/// Note that unlike other generating functions, one of the dimensions in the shape can be negative.
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2159,8 +2259,9 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
if let BasicValueEnum::PointerValue(n1) = x1 { if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n1 = NDArrayValue::from_pointer_value(n1, llvm_elem_ty, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
@ -2169,11 +2270,11 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
let out = match shape { let out = match shape {
BasicValueEnum::PointerValue(shape_list_ptr) BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() => if ListValue::is_representable(shape_list_ptr, llvm_usize).is_ok() =>
{ {
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])` // 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None); let shape_list = ListValue::from_pointer_value(shape_list_ptr, llvm_usize, None);
// Check for -1 in dimensions // Check for -1 in dimensions
gen_for_callback_incrementing( gen_for_callback_incrementing(
generator, generator,
@ -2370,7 +2471,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
.into_int_value(); .into_int_value();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
} }
_ => unreachable!(), _ => codegen_unreachable!(ctx),
} }
.unwrap(); .unwrap();
@ -2388,7 +2489,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
); );
// The new shape must be compatible with the old shape // The new shape must be compatible with the old shape
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None)); let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None));
ctx.make_assert( ctx.make_assert(
generator, generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
@ -2414,7 +2515,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
Ok(out.as_base_value().into()) Ok(out.as_base_value().into())
} else { } else {
unreachable!( codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'", "{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty)) format!("'{}'", ctx.unifier.stringify(x1_ty))
) )
@ -2435,17 +2537,22 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_dot"; const FN_NAME: &str = "ndarray_dot";
let (x1_ty, x1) = x1; let (x1_ty, x1) = x1;
let (_, x2) = x2; let (x2_ty, x2) = x2;
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
match (x1, x2) { match (x1, x2) {
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x1_ty);
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None); let n2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, x2_ty);
let llvm_n1_data_ty = ctx.get_llvm_type(generator, n1_dtype);
let llvm_n2_data_ty = ctx.get_llvm_type(generator, n2_dtype);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n1 = NDArrayValue::from_pointer_value(n1, llvm_n1_data_ty, llvm_usize, None);
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); let n2 = NDArrayValue::from_pointer_value(n2, llvm_n2_data_ty, llvm_usize, None);
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
ctx.make_assert( ctx.make_assert(
generator, generator,
@ -2482,7 +2589,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
.build_float_mul(e1, elem2.into_float_value(), "") .build_float_mul(e1, elem2.into_float_value(), "")
.unwrap() .unwrap()
.as_basic_value_enum(), .as_basic_value_enum(),
_ => unreachable!(), _ => codegen_unreachable!(ctx, "product: {}", elem1.get_type()),
}; };
let acc_val = ctx.builder.build_load(acc, "").unwrap(); let acc_val = ctx.builder.build_load(acc, "").unwrap();
let acc_val = match acc_val { let acc_val = match acc_val {
@ -2496,7 +2603,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
.build_float_add(e1, product.into_float_value(), "") .build_float_add(e1, product.into_float_value(), "")
.unwrap() .unwrap()
.as_basic_value_enum(), .as_basic_value_enum(),
_ => unreachable!(), _ => codegen_unreachable!(ctx, "acc_val: {}", acc_val.get_type()),
}; };
ctx.builder.build_store(acc, acc_val).unwrap(); ctx.builder.build_store(acc, acc_val).unwrap();
@ -2513,7 +2620,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => { (BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum()) Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
} }
_ => unreachable!( _ => codegen_unreachable!(
ctx,
"{FN_NAME}() not supported for '{}'", "{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty)) format!("'{}'", ctx.unifier.stringify(x1_ty))
), ),

View File

@ -1,210 +0,0 @@
// TODO: Replace numpy.rs
use inkwell::values::{BasicValue, BasicValueEnum};
use nac3parser::ast::StrRef;
use crate::{
codegen::object::{ndarray::scalar::split_scalar_or_ndarray, tuple::TupleObject},
symbol_resolver::ValueEnum,
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId},
typecheck::typedef::{FunSignature, Type},
};
use super::{
irrt::call_nac3_ndarray_util_assert_shape_no_negative,
model::*,
object::{
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
AnyObject,
},
CodeGenContext, CodeGenerator,
};
/// Generates LLVM IR for `np.broadcast_to`.
pub fn gen_ndarray_broadcast_to<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
// Parse argument #1 input
let input_ty = fun.0.args[0].ty;
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
let input = AnyObject { ty: input_ty, value: input };
// Parse argument #2 shape
let shape_ty = fun.0.args[1].ty;
let shape = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Define models
let sizet_model = IntModel(SizeT);
// Extract broadcast_ndims, this is the only way to get the
// ndims of the ndarray result statically.
let (_, broadcast_ndims_ty) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let broadcast_ndims = extract_ndims(&ctx.unifier, broadcast_ndims_ty);
// Process `input`
let in_ndarray = split_scalar_or_ndarray(generator, ctx, input).as_ndarray(generator, ctx);
// Process `shape`
let (_, broadcast_shape) = parse_numpy_int_sequence(generator, ctx, shape);
// NOTE: shape.size should equal to `broadcasted_ndims`.
let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
call_nac3_ndarray_util_assert_shape_no_negative(
generator,
ctx,
broadcast_ndims_llvm,
broadcast_shape,
);
// Create broadcast view
let broadcast_ndarray =
in_ndarray.broadcast_to(generator, ctx, broadcast_ndims, broadcast_shape);
Ok(broadcast_ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.reshape`.
pub fn gen_ndarray_reshape<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 2);
// Parse argument #1 input
let input_ty = fun.0.args[0].ty;
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
let input = AnyObject { ty: input_ty, value: input };
// Parse argument #2 shape
let shape_ty = fun.0.args[1].ty;
let shape = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
let shape = AnyObject { ty: shape_ty, value: shape };
// Extract reshaped_ndims
let (_, reshaped_ndims_ty) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
let reshaped_ndims = extract_ndims(&ctx.unifier, reshaped_ndims_ty);
// Process `input`
let in_ndarray = split_scalar_or_ndarray(generator, ctx, input).as_ndarray(generator, ctx);
// Process the shape input from user and resolve negative indices.
// The resulting `new_shape`'s size should be equal to reshaped_ndims.
// This is ensured by the typechecker.
let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape);
let reshaped_ndarray = in_ndarray.reshape_or_copy(generator, ctx, reshaped_ndims, new_shape);
Ok(reshaped_ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.arange`.
pub fn gen_ndarray_arange<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse argument #1 len
let input_ty = fun.0.args[0].ty;
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?.into_int_value();
// Implementation
let input_dim = IntModel(SizeT).s_extend_or_bit_cast(generator, ctx, input, "input_dim");
let ndarray = NDArrayObject::from_np_arange(generator, ctx, input_dim);
Ok(ndarray.instance.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.size`.
pub fn gen_ndarray_size<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let size = ndarray.size(generator, ctx).truncate(generator, ctx, Int32, "size");
Ok(size.value.as_basic_value_enum())
}
/// Generates LLVM IR for `np.shape`.
pub fn gen_ndarray_shape<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse argument #1 ndarray
let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
// Process ndarray
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
Ok(ndarray.make_shape_tuple(generator, ctx).value.as_basic_value_enum())
}
/// Generates LLVM IR for `<ndarray>.strides`.
pub fn gen_ndarray_strides<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<BasicValueEnum<'ctx>, String> {
// TODO: Code duplication: This function looks exactly like `gen_ndarray_shapes`.
assert!(obj.is_none());
assert_eq!(args.len(), 1);
// Parse argument #1 ndarray
let ndarray_ty = fun.0.args[0].ty;
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
// Process ndarray
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
let mut objects = Vec::with_capacity(ndarray.ndims as usize);
for i in 0..ndarray.ndims {
let dim = ndarray
.instance
.get(generator, ctx, |f| f.strides, "")
.offset_const(generator, ctx, i, "")
.load(generator, ctx, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
objects
.push(AnyObject { ty: ctx.primitives.int32, value: dim.value.as_basic_value_enum() });
}
let strides = TupleObject::create(generator, ctx, objects, "strides");
Ok(strides.value.as_basic_value_enum())
}

View File

@ -1,121 +0,0 @@
use crate::{
codegen::{
irrt::{call_nac3_list_slice_assign, list_slice_assignment},
model::*,
object::ndarray::indexing::UserSlice,
structure::List,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
};
use super::{ndarray::indexing::RustUserSlice, AnyObject};
/// A NAC3 Python List object.
#[derive(Debug, Clone, Copy)]
pub struct ListObject<'ctx> {
/// Typechecker type of the list items
pub item_type: Type,
pub instance: Ptr<'ctx, StructModel<List<AnyModel<'ctx>>>>,
}
impl<'ctx> ListObject<'ctx> {
/// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`].
pub fn from_object<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
) -> Self {
// Check typechecker type and extract `item_type`
let item_type = match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
iter_type_vars(params).next().unwrap().ty // Extract `item_type`
}
_ => {
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(object.ty))
}
};
let item_model = AnyModel(ctx.get_llvm_type(generator, item_type));
let plist_model = PtrModel(StructModel(List { item: item_model }));
// Create object
let value = plist_model.check_value(generator, ctx.ctx, object.value).unwrap();
ListObject { item_type, instance: value }
}
/// Get the `items` field as an opaque pointer.
pub fn get_opaque_items_ptr<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Ptr<'ctx, IntModel<Byte>> {
self.instance.get(generator, ctx, |f| f.items, "items").pointer_cast(
generator,
ctx,
IntModel(Byte),
"items_opaque",
)
}
/// Get the value of this [`ListObject`] as a list with opaque items.
///
/// This function allocates on the stack to create the list, but the
/// reference to the `items` are preserved.
pub fn get_opaque_list_ptr<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Ptr<'ctx, StructModel<List<IntModel<Byte>>>> {
let opaque_list_model = StructModel(List { item: IntModel(Byte) });
let opaque_list_ptr = opaque_list_model.alloca(generator, ctx, "opaque_list_ptr");
// Copy items pointer
let items = self.get_opaque_items_ptr(generator, ctx);
opaque_list_ptr.set(ctx, |f| f.items, items);
// Copy len
let len = self.instance.get(generator, ctx, |f| f.len, "len");
opaque_list_ptr.set(ctx, |f| f.len, len);
opaque_list_ptr
}
/// Get the `len()` of this list.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
self.instance.get(generator, ctx, |f| f.len, "list_len")
}
pub fn slice_assign_from<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
user_slice: &RustUserSlice<'ctx>,
source: ListObject<'ctx>,
) {
// Sanity check
assert!(ctx.unifier.unioned(self.item_type, source.item_type));
let user_slice_model = StructModel(UserSlice);
let puser_slice = user_slice_model.alloca(generator, ctx, "user_slice");
user_slice.write_to_user_slice(generator, ctx, puser_slice);
let itemsize = self.instance.model.get_type(generator, ctx.ctx).size_of();
call_nac3_list_slice_assign(
generator,
ctx,
self.get_opaque_list_ptr(generator, ctx),
source.instance.value,
itemsize,
user_slice,
);
todo!()
}
}

View File

@ -1,608 +0,0 @@
use inkwell::{
values::{BasicValue, BasicValueEnum, FloatValue, IntValue},
FloatPredicate, IntPredicate,
};
use itertools::Itertools;
use list::ListObject;
use ndarray::{NDArrayObject, NDArrayOut};
use range::RangeObject;
use tuple::TupleObject;
use crate::{
toplevel::helper::PrimDef,
typecheck::typedef::{Type, TypeEnum},
};
use super::{llvm_intrinsics, model::*, CodeGenContext, CodeGenerator};
pub mod list;
pub mod ndarray;
pub mod range;
pub mod tuple;
/// Convenience function to crash the program when types of arguments are not supported.
/// Used to be debugged with a stacktrace.
fn unsupported_type<I>(ctx: &CodeGenContext<'_, '_>, tys: I) -> !
where
I: IntoIterator<Item = Type>,
{
unreachable!(
"unsupported types found '{}'",
tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "),
)
}
#[derive(Debug, Clone, Copy)]
pub enum FloorOrCeil {
Floor,
Ceil,
}
#[derive(Debug, Clone, Copy)]
pub enum MinOrMax {
Min,
Max,
}
fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.int32, ctx.primitives.int64]
}
fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.uint32, ctx.primitives.uint64]
}
fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64]
}
fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
]
}
#[derive(Debug, Clone, Copy)]
pub struct AnyObject<'ctx> {
pub ty: Type,
pub value: BasicValueEnum<'ctx>,
}
impl<'ctx> AnyObject<'ctx> {
/// Returns true if this object's type is a [`TypeEnum::TObj`] and has the object ID as `prim`.
pub fn is_obj(&self, ctx: &mut CodeGenContext<'ctx, '_>, prim: PrimDef) -> bool {
match &*ctx.unifier.get_ty(self.ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id == prim.id(),
_ => false,
}
}
/// Returns true if this object's type is a [`TypeEnum::TTuple`]
pub fn is_tuple(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
matches!(&*ctx.unifier.get_ty(self.ty), TypeEnum::TTuple { .. })
}
pub fn into_tuple() {}
pub fn is_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.int32)
}
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Int<'ctx, Int32> {
assert!(self.is_int32(ctx));
IntModel(Int32).believe_value(self.value.into_int_value())
}
pub fn is_uint32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.uint32)
}
pub fn is_int64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.int64)
}
pub fn is_uint64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.uint64)
}
pub fn is_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.bool)
}
/// Returns true if the object type is `bool`, `int32`, `int64`, `uint32`, or `uint64`.
pub fn is_int_like(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, int_like(ctx))
}
/// Returns true if the object type is `int32`, `int64`.
pub fn is_signed_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, signed_ints(ctx))
}
/// Returns true if the object type is `uint32`, `uint64`.
pub fn is_unsigned_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, unsigned_ints(ctx))
}
pub fn into_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
assert!(self.is_int_like(ctx));
self.value.into_int_value()
}
pub fn is_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
self.is_obj(ctx, PrimDef::Float)
}
pub fn into_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Float<'ctx, Float64> {
assert!(self.is_float(ctx));
FloatModel(Float64).believe_value(self.value.into_float_value())
}
pub fn is_ndarray(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
self.is_obj(ctx, PrimDef::NDArray)
}
pub fn into_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> NDArrayObject<'ctx> {
NDArrayObject::from_object(generator, ctx, *self)
}
/// Create an object from a boolean from an i1.
///
/// NOTE: In NAC3, booleans are i8. This function does converts the input i1 to an i8.
pub fn from_bool(ctx: &mut CodeGenContext<'ctx, '_>, n: Int<'ctx, Bool>) -> AnyObject<'ctx> {
let llvm_i8 = ctx.ctx.i8_type();
let value = ctx.builder.build_int_z_extend(n.value, llvm_i8, "bool").unwrap();
AnyObject { value: value.as_basic_value_enum(), ty: ctx.primitives.bool }
}
/// Helper function to compare two scalars.
///
/// Only int-to-int and float-to-float comparisons are allowed.
///
/// Panic otherwise.
pub fn compare_int_or_float_by_predicate<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: AnyObject<'ctx>,
rhs: AnyObject<'ctx>,
int_predicate: IntPredicate,
float_predicate: FloatPredicate,
name: &str,
) -> Int<'ctx, Bool> {
assert!(ctx.unifier.unioned(lhs.ty, rhs.ty), "lhs and rhs type should be the same");
let bool_model = IntModel(Bool);
let common_ty = lhs.ty;
let result = if lhs.is_float(ctx) {
let lhs = lhs.into_float(ctx);
let rhs = rhs.into_float(ctx);
ctx.builder.build_float_compare(float_predicate, lhs.value, rhs.value, name).unwrap()
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
let lhs = lhs.into_int(ctx);
let rhs = rhs.into_int(ctx);
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
} else {
unsupported_type(ctx, [lhs.ty, rhs.ty]);
};
bool_model.check_value(generator, ctx.ctx, result).unwrap()
}
/// Helper function for `int32()`, `int64()`, `uint32()`, and `uint64()`.
///
/// TODO: Document me
fn cast_to_int_conversion<'a, G, HandleFloatFn>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ret_int_ty: Type,
handle_float: HandleFloatFn,
) -> AnyObject<'ctx>
where
G: CodeGenerator + ?Sized,
HandleFloatFn:
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
{
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
let result = if self.is_float(ctx) {
// Handle float to int
let n = self.into_float(ctx);
handle_float(generator, ctx, n.value)
} else if self.is_int_like(ctx) {
// Handle int to a new int type
let n = self.into_int(ctx);
if n.get_type().get_bit_width() <= ret_int_ty_llvm.get_bit_width() {
ctx.builder.build_int_z_extend(n, ret_int_ty_llvm, "zext").unwrap()
} else {
ctx.builder.build_int_truncate(n, ret_int_ty_llvm, "trunc").unwrap()
}
} else {
unsupported_type(ctx, [self.ty]);
};
assert_eq!(ret_int_ty_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
AnyObject { value: result.into(), ty: ret_int_ty }
}
/// Call `int32()` on this object.
#[must_use]
pub fn call_int32<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
self.cast_to_int_conversion(
generator,
ctx,
ctx.primitives.int32,
|_generator, ctx, input| {
let n =
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap();
ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap()
},
)
}
/// Call `int64()` on this object.
#[must_use]
pub fn call_int64<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
self.cast_to_int_conversion(
generator,
ctx,
ctx.primitives.int64,
|_generator, ctx, input| {
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap()
},
)
}
/// Call `uint32()` on this object.
#[must_use]
pub fn call_uint32<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
self.cast_to_int_conversion(generator, ctx, ctx.primitives.uint32, |_generator, ctx, n| {
let n_gez = ctx
.builder
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
.unwrap();
let to_int32 =
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap();
let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder
.build_select(
n_gez,
ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(),
to_int32,
"conv",
)
.unwrap()
.into_int_value()
})
}
/// Call `uint64()` on this object.
#[must_use]
pub fn call_uint64<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
self.cast_to_int_conversion(generator, ctx, ctx.primitives.uint64, |_generator, ctx, n| {
let val_gez = ctx
.builder
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
.unwrap();
let to_int64 =
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap().into_int_value()
})
}
// Get the `len()` of this object.
#[must_use]
pub fn call_len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
// TODO: Switch to returning SizeT
let result = match &*ctx.unifier.get_ty_immutable(self.ty) {
TypeEnum::TTuple { .. } => {
let tuple = TupleObject::from_object(ctx, *self);
tuple.len(generator, ctx).truncate(generator, ctx, Int32, "tuple_len_32")
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
{
let range = RangeObject::from_object(generator, ctx, *self);
range.len(generator, ctx)
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let list = ListObject::from_object(generator, ctx, *self);
list.len(generator, ctx).truncate(generator, ctx, Int32, "list_len_i32")
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayObject::from_object(generator, ctx, *self);
ndarray.len(generator, ctx).truncate(generator, ctx, Int32, "ndarray_len_i32")
}
_ => unreachable!(),
};
AnyObject { ty: ctx.primitives.int32, value: result.value.as_basic_value_enum() }
}
/// Like [`AnyObject::call_bool`] but this returns an `Int<'ctx, Bool>` instead of an object.
pub fn bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Int<'ctx, Bool> {
let bool_model = IntModel(Bool);
if self.is_int_like(ctx) {
let n = self.into_int(ctx);
let n = ctx
.builder
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
.unwrap();
bool_model.believe_value(n)
} else if self.is_float(ctx) {
let n = self.value.into_float_value();
let n = ctx
.builder
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
.unwrap();
bool_model.believe_value(n)
} else {
unsupported_type(ctx, [self.ty])
}
}
/// Call `bool()` on this object.
#[must_use]
pub fn call_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
let n = self.bool(ctx);
AnyObject::from_bool(ctx, n)
}
/// Call `float()` on this object.
#[must_use]
pub fn call_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
let f64_model = FloatModel(Float64);
let llvm_f64 = ctx.ctx.f64_type();
let result = if self.is_float(ctx) {
self.into_float(ctx)
} else if self.is_signed_int(ctx) || self.is_bool(ctx) {
let n = self.into_int(ctx);
let n = ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap();
f64_model.believe_value(n)
} else if self.is_unsigned_int(ctx) {
let n = self.into_int(ctx);
let n = ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap();
f64_model.believe_value(n)
} else {
unsupported_type(ctx, [self.ty]);
};
AnyObject { ty: ctx.primitives.float, value: result.value.as_basic_value_enum() }
}
// Call `abs()` on this object.
#[must_use]
pub fn call_abs<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
if self.is_float(ctx) {
let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
AnyObject { value: n.into(), ty: ctx.primitives.float }
} else if self.is_unsigned_int(ctx) || self.is_signed_int(ctx) {
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
let n = self.value.into_int_value();
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
AnyObject { value: n.into(), ty: self.ty }
} else if self.is_ndarray(ctx) {
let ndarray = self.into_ndarray(generator, ctx);
ndarray
.map(
generator,
ctx,
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|generator, ctx, scalar| Ok(scalar.call_abs(generator, ctx)),
)
.unwrap()
.to_any_object(ctx)
} else {
unsupported_type(ctx, [self.ty])
}
}
// Call `round()` on this object.
//
// It is possible to specify which kind of int type to return.
#[must_use]
pub fn call_round<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ret_int_ty: Type,
) -> AnyObject<'ctx> {
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
let result = if ctx.unifier.unioned(self.ty, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, "round").unwrap()
} else {
unsupported_type(ctx, [self.ty])
};
AnyObject { ty: ret_int_ty, value: result.as_basic_value_enum() }
}
/// Call `np_round()` on this object.
///
/// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type.
#[must_use]
pub fn call_np_round<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
if self.is_float(ctx) {
let n = self.into_float(ctx);
let n = llvm_intrinsics::call_float_rint(ctx, n.value, None);
AnyObject { ty: ctx.primitives.float, value: n.as_basic_value_enum() }
} else if self.is_ndarray(ctx) {
let ndarray = self.into_ndarray(generator, ctx);
ndarray
.map(
generator,
ctx,
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|generator, ctx, scalar| Ok(scalar.call_np_round(generator, ctx)),
)
.unwrap()
.to_any_object(ctx)
} else {
unsupported_type(ctx, [self.ty])
}
}
/// Call `min()` or `max()` on two objects.
#[must_use]
pub fn call_min_or_max(
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
a: AnyObject<'ctx>,
b: AnyObject<'ctx>,
) -> AnyObject<'ctx> {
if !ctx.unifier.unioned(a.ty, b.ty) {
unsupported_type(ctx, [a.ty, b.ty])
}
let common_ty = a.ty;
if a.is_float(ctx) {
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_float_minnum,
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
};
let a = a.into_float(ctx).value;
let b = b.into_float(ctx).value;
let result = function(ctx, a, b, None);
AnyObject { value: result.as_basic_value_enum(), ty: ctx.primitives.float }
} else if a.is_unsigned_int(ctx) || a.is_bool(ctx) {
// Treating bool has an unsigned int since that is convenient
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_int_umin,
MinOrMax::Max => llvm_intrinsics::call_int_umax,
};
let a = a.into_int(ctx);
let b = b.into_int(ctx);
let result = function(ctx, a, b, None);
AnyObject { value: result.as_basic_value_enum(), ty: common_ty }
} else if a.is_signed_int(ctx) {
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_int_smin,
MinOrMax::Max => llvm_intrinsics::call_int_smax,
};
let a = a.into_int(ctx);
let b = b.into_int(ctx);
let result = function(ctx, a, b, None);
AnyObject { value: result.as_basic_value_enum(), ty: common_ty }
} else {
unsupported_type(ctx, [common_ty])
}
}
/// Call `floor()` or `ceil()` on this object.
///
/// It is possible to specify which kind of int type to return.
#[must_use]
pub fn call_floor_or_ceil<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: FloorOrCeil,
ret_int_ty: Type,
) -> Self {
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
if self.is_float(ctx) {
let function = match kind {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.into_float(ctx).value;
let n = function(ctx, n, None);
let n = ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, "").unwrap();
AnyObject { ty: ret_int_ty, value: n.as_basic_value_enum() }
} else {
unsupported_type(ctx, [self.ty])
}
}
/// Call `np_floor()` or `np_ceil()` on this object.
#[must_use]
pub fn call_np_floor_or_ceil<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: FloorOrCeil,
) -> Self {
// TODO:
if self.is_float(ctx) {
let function = match kind {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.into_float(ctx).value;
let n = function(ctx, n, None);
AnyObject { ty: ctx.primitives.float, value: n.as_basic_value_enum() }
} else if self.is_ndarray(ctx) {
let ndarray = self.into_ndarray(generator, ctx);
ndarray
.map(
generator,
ctx,
NDArrayOut::NewNDArray { dtype: ctx.primitives.float },
|generator, ctx, scalar| Ok(scalar.call_np_floor_or_ceil(generator, ctx, kind)),
)
.unwrap()
.to_any_object(ctx)
} else {
unsupported_type(ctx, [self.ty])
}
}
}

View File

@ -1,176 +0,0 @@
use super::NDArrayObject;
use crate::{
codegen::{
irrt::{call_nac3_array_set_and_validate_list_shape, call_nac3_array_write_list_to_array},
model::*,
object::{list::ListObject, AnyObject},
stmt::gen_if_else_expr_callback,
CodeGenContext, CodeGenerator,
},
toplevel::helper::{arraylike_flatten_element_type, arraylike_get_ndims},
typecheck::typedef::{Type, TypeEnum},
};
fn get_list_object_dtype_and_ndims<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
) -> (Type, u64) {
let dtype = arraylike_flatten_element_type(&mut ctx.unifier, list.item_type);
let ndims = arraylike_get_ndims(&mut ctx.unifier, list.item_type);
let ndims = ndims + 1; // To count `list` itself.
(dtype, ndims)
}
impl<'ctx> NDArrayObject<'ctx> {
fn from_np_array_list_copy_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
) -> Self {
let sizet_model = IntModel(SizeT);
let (dtype, ndims_int) = get_list_object_dtype_and_ndims(ctx, list);
let list_value = list.get_opaque_list_ptr(generator, ctx);
// Validate `list` has a consistent shape.
// Raise an exception if `list` is something abnormal like `[[1, 2], [3]]`.
// If `list` has a consistent shape, deduce the shape and write it to `shape`.
let ndims = sizet_model.constant(generator, ctx.ctx, ndims_int);
let shape = sizet_model.array_alloca(generator, ctx, ndims.value, "shape");
call_nac3_array_set_and_validate_list_shape(generator, ctx, list_value, ndims, shape);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims_int, "ndarray_from_list");
ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.create_data(generator, ctx);
// Copy all contents from the list.
call_nac3_array_write_list_to_array(generator, ctx, list_value, ndarray.instance);
ndarray
}
fn from_np_array_list_try_no_copy_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
) -> Self {
// np_array without copy is only possible `list` is not nested.
// If `list` is `list[T]`, we can create an ndarray with `data` set
// to the array pointer of `list`.
let sizet_model = IntModel(SizeT);
let zero = sizet_model.const_0(generator, ctx.ctx);
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
if ndims == 1 {
// `list` is not nested, does not need to copy.
let ndarray =
NDArrayObject::alloca(generator, ctx, dtype, 1, "ndarray_from_list_no_copy");
// Set data
let data = list.get_opaque_items_ptr(generator, ctx);
ndarray.instance.set(ctx, |f| f.data, data);
// Set shape
// dim = list->len;
// shape[0] = dim;
let shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape");
let dim = list.instance.get(generator, ctx, |f| f.len, "dim");
shape.offset(generator, ctx, zero.value, "pdim").store(ctx, dim);
// Set strides, the `data` is contiguous
ndarray.update_strides_by_shape(generator, ctx);
// Done
ndarray
} else {
// `list` is nested, it is impossible to not copy.
NDArrayObject::from_np_array_list_copy_impl(generator, ctx, list)
}
}
fn from_np_array_list_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
list: ListObject<'ctx>,
copy: Int<'ctx, Bool>,
) -> Self {
let (dtype, ndims) = get_list_object_dtype_and_ndims(ctx, list);
let ndarray = gen_if_else_expr_callback(
generator,
ctx,
|_generator, _ctx| Ok(copy.value),
|generator, ctx| {
let ndarray = NDArrayObject::from_np_array_list_copy_impl(generator, ctx, list);
Ok(Some(ndarray.instance.value))
},
|generator, ctx| {
let ndarray =
NDArrayObject::from_np_array_list_try_no_copy_impl(generator, ctx, list);
Ok(Some(ndarray.instance.value))
},
)
.unwrap()
.unwrap();
NDArrayObject::from_value_and_unpacked_types(generator, ctx, ndarray, dtype, ndims)
}
pub fn from_np_array_ndarray_impl<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayObject<'ctx>,
copy: Int<'ctx, Bool>,
) -> Self {
let ndarray_val = gen_if_else_expr_callback(
generator,
ctx,
|_generator, _ctx| Ok(copy.value),
|generator, ctx| {
let ndarray = ndarray.make_copy(generator, ctx, "np_array_copied_ndarray"); // Force copy
Ok(Some(ndarray.instance.value))
},
|_generator, _ctx| {
// No need to copy. Return `ndarray` itself.
Ok(Some(ndarray.instance.value))
},
)
.unwrap()
.unwrap();
NDArrayObject::from_value_and_unpacked_types(
generator,
ctx,
ndarray_val,
ndarray.dtype,
ndarray.ndims,
)
}
pub fn from_np_array<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
copy: Int<'ctx, Bool>,
) -> Self {
match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let list = ListObject::from_object(generator, ctx, object);
NDArrayObject::from_np_array_list_impl(generator, ctx, list, copy)
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayObject::from_object(generator, ctx, object);
NDArrayObject::from_np_array_ndarray_impl(generator, ctx, ndarray, copy)
}
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object.ty)), // Typechecker ensures this
}
}
}

View File

@ -1,161 +0,0 @@
use itertools::Itertools;
use crate::codegen::{
irrt::{
call_nac3_ndarray_broadcast_shapes, call_nac3_ndarray_broadcast_to,
call_nac3_ndarray_util_assert_shape_no_negative,
},
model::*,
CodeGenContext, CodeGenerator,
};
use super::NDArrayObject;
/// Fields of [`ShapeEntry`]
pub struct ShapeEntryFields<'ctx, F: FieldTraversal<'ctx>> {
pub ndims: F::Out<IntModel<SizeT>>,
pub shape: F::Out<PtrModel<IntModel<SizeT>>>,
}
/// An IRRT structure used in broadcasting.
#[derive(Debug, Clone, Copy, Default)]
pub struct ShapeEntry;
impl<'ctx> StructKind<'ctx> for ShapeEntry {
type Fields<F: FieldTraversal<'ctx>> = ShapeEntryFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { ndims: traversal.add_auto("ndims"), shape: traversal.add_auto("shape") }
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Create a broadcast view on this ndarray with a target shape.
///
/// The input shape will be checked to make sure that it contains no negative values.
///
/// * `target_ndims` - The ndims type after broadcasting to the given shape.
/// The caller has to figure this out for this function.
/// * `target_shape` - An array pointer pointing to the target shape.
#[must_use]
pub fn broadcast_to<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
target_ndims: u64,
target_shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
let target_ndims_llvm = IntModel(SizeT).constant(generator, ctx.ctx, target_ndims);
call_nac3_ndarray_util_assert_shape_no_negative(
generator,
ctx,
target_ndims_llvm,
target_shape,
);
let broadcast_ndarray = NDArrayObject::alloca(
generator,
ctx,
self.dtype,
target_ndims,
"broadcast_ndarray_to_dst",
);
broadcast_ndarray.copy_shape_from_array(generator, ctx, target_shape);
call_nac3_ndarray_broadcast_to(generator, ctx, self.instance, broadcast_ndarray.instance);
broadcast_ndarray
}
}
/// A result produced by [`broadcast_all_ndarrays`]
#[derive(Debug, Clone)]
pub struct BroadcastAllResult<'ctx> {
/// The statically known `ndims` of the broadcast result.
pub ndims: u64,
/// The broadcasting shape.
pub shape: Ptr<'ctx, IntModel<SizeT>>,
/// Broadcasted views on the inputs.
///
/// All of them will have `shape` [`BroadcastAllResult::shape`] and
/// `ndims` [`BroadcastAllResult::ndims`]. The length of the vector
/// is the same as the input.
pub ndarrays: Vec<NDArrayObject<'ctx>>,
}
pub fn broadcast_shapes<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
in_entries: &[(Ptr<'ctx, IntModel<SizeT>>, u64)],
broadcast_ndims: u64,
broadcast_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let sizet_model = IntModel(SizeT);
let shape_model = StructModel(ShapeEntry);
// Prepare input shape entries
let num_shape_entries =
sizet_model.constant(generator, ctx.ctx, u64::try_from(in_entries.len()).unwrap());
let shape_entries =
shape_model.array_alloca(generator, ctx, num_shape_entries.value, "shape_entries");
for (i, (in_shape, in_ndims)) in in_entries.iter().enumerate() {
let i = sizet_model.constant(generator, ctx.ctx, i as u64).value;
let pshape_entry = shape_entries.offset(generator, ctx, i, "shape_entry");
let in_ndims = sizet_model.constant(generator, ctx.ctx, *in_ndims);
pshape_entry.set(ctx, |f| f.ndims, in_ndims);
pshape_entry.set(ctx, |f| f.shape, *in_shape);
}
let broadcast_ndims = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
call_nac3_ndarray_broadcast_shapes(
generator,
ctx,
num_shape_entries,
shape_entries,
broadcast_ndims,
broadcast_shape,
);
}
impl<'ctx> NDArrayObject<'ctx> {
// TODO: DOCUMENT: Behaves like `np.broadcast()`, except returns results differently.
pub fn broadcast<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarrays: &[Self],
) -> BroadcastAllResult<'ctx> {
assert!(!ndarrays.is_empty());
let sizet_model = IntModel(SizeT);
// Infer the broadcast output ndims.
let broadcast_ndims_int = ndarrays.iter().map(|ndarray| ndarray.ndims).max().unwrap();
let broadcast_ndims = sizet_model.constant(generator, ctx.ctx, broadcast_ndims_int);
let broadcast_shape =
sizet_model.array_alloca(generator, ctx, broadcast_ndims.value, "broadcast_shape");
let shape_entries = ndarrays
.iter()
.map(|ndarray| {
(ndarray.instance.get(generator, ctx, |f| f.shape, "shape"), ndarray.ndims)
})
.collect_vec();
broadcast_shapes(generator, ctx, &shape_entries, broadcast_ndims_int, broadcast_shape);
// Broadcast all the inputs to shape `dst_shape`.
let broadcast_ndarrays: Vec<_> = ndarrays
.iter()
.map(|ndarray| {
ndarray.broadcast_to(generator, ctx, broadcast_ndims_int, broadcast_shape)
})
.collect_vec();
BroadcastAllResult {
ndims: broadcast_ndims_int,
shape: broadcast_shape,
ndarrays: broadcast_ndarrays,
}
}
}

View File

@ -1,238 +0,0 @@
use inkwell::{values::BasicValueEnum, IntPredicate};
use super::NDArrayObject;
use crate::{
codegen::{
irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, object::AnyObject,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type,
};
/// Get the zero value in `np.zeros()` of a `dtype`.
fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.ctx.i32_type().const_zero().into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
ctx.ctx.i64_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.ctx.f64_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.ctx.bool_type().const_zero().into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string(generator, "").value.into()
} else {
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
}
}
/// Get the one value in `np.ones()` of a `dtype`.
fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
) -> BasicValueEnum<'ctx> {
if [ctx.primitives.int32, ctx.primitives.uint32]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32);
ctx.ctx.i32_type().const_int(1, is_signed).into()
} else if [ctx.primitives.int64, ctx.primitives.uint64]
.iter()
.any(|ty| ctx.unifier.unioned(dtype, *ty))
{
let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64);
ctx.ctx.i64_type().const_int(1, is_signed).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.float) {
ctx.ctx.f64_type().const_float(1.0).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.bool) {
ctx.ctx.bool_type().const_int(1, false).into()
} else if ctx.unifier.unioned(dtype, ctx.primitives.str) {
ctx.gen_string(generator, "1").value.into()
} else {
panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype));
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Create an ndarray like `np.empty`.
pub fn from_np_empty<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
// Validate `shape`
// TODO: Should the caller be responsible for this instead?
let ndims_llvm = IntModel(SizeT).constant(generator, ctx.ctx, ndims);
call_nac3_ndarray_util_assert_shape_no_negative(generator, ctx, ndims_llvm, shape);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims, "full_ndarray");
ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.create_data(generator, ctx);
ndarray
}
/// Create an ndarray like `np.full`.
pub fn from_np_full<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
fill_value: AnyObject<'ctx>,
) -> Self {
// Sanity check on `fill_value`'s dtype.
assert!(ctx.unifier.unioned(dtype, fill_value.ty));
let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape);
ndarray.fill(generator, ctx, fill_value);
ndarray
}
/// Create an ndarray like `np.zero`.
pub fn from_np_zero<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
let fill_value = ndarray_zero_value(generator, ctx, dtype);
let fill_value = AnyObject { value: fill_value, ty: dtype };
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
}
/// Create an ndarray like `np.ones`.
pub fn from_np_ones<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
let fill_value = ndarray_one_value(generator, ctx, dtype);
let fill_value = AnyObject { value: fill_value, ty: dtype };
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
}
/// Create an ndarray like `np.arange`.
/// The returned ndarray's `dtype` is always `float`
pub fn from_np_arange<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
length: Int<'ctx, SizeT>,
) -> Self {
let ndarray = NDArrayObject::alloca(
generator,
ctx,
ctx.primitives.float,
1, // ndims = 1
"arange_ndarray",
);
// `ndarray.shape[0] = length`
ndarray
.instance
.get(generator, ctx, |f| f.shape, "shape")
.offset_const(generator, ctx, 0, "dim")
.store(ctx, length);
// Create data and set elements
ndarray.create_data(generator, ctx);
ndarray
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
// Get the index of the current element, convert that index to float, and write it.
// This is how we get [0.0, 1.0, 2.0, ...].
let index = nditer.get_index(generator, ctx);
let pelement = nditer.get_pointer(generator, ctx);
let val = ctx
.builder
.build_unsigned_int_to_float(index.value, ctx.ctx.f64_type(), "val")
.unwrap();
ctx.builder.build_store(pelement, val).unwrap();
Ok(())
})
.unwrap();
ndarray
}
/// Create an ndarray like `np.eye`.
pub fn from_np_eye<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
nrows: Int<'ctx, SizeT>,
ncols: Int<'ctx, SizeT>,
offset: Int<'ctx, SizeT>,
) -> Self {
let ndzero = ndarray_zero_value(generator, ctx, dtype);
let ndone = ndarray_one_value(generator, ctx, dtype);
let ndarray = NDArrayObject::alloca_dynamic_shape(
generator,
ctx,
dtype,
&[nrows, ncols],
"eye_ndarray",
);
// Create data and make the matrix like look np.eye()
ndarray.create_data(generator, ctx);
ndarray
.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
// NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero
// and this loop would not execute.
// Load up `row_i` and `col_i` from indices.
let row_i = nditer
.get_indices()
.offset_const(generator, ctx, 0, "")
.load(generator, ctx, "row_i");
let col_i = nditer
.get_indices()
.offset_const(generator, ctx, 1, "")
.load(generator, ctx, "col_i");
// Write to element
let be_one =
row_i.add(ctx, offset, "").compare(ctx, IntPredicate::EQ, col_i, "write_one");
let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap();
let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, value).unwrap();
Ok(())
})
.unwrap();
ndarray
}
/// Create an ndarray like `np.identity`.
pub fn from_np_identity<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
size: Int<'ctx, SizeT>,
) -> Self {
// Convenient implementation
let offset = IntModel(SizeT).const_0(generator, ctx.ctx);
NDArrayObject::from_np_eye(generator, ctx, dtype, size, size, offset)
}
}

View File

@ -1,128 +0,0 @@
use inkwell::{FloatPredicate, IntPredicate};
use crate::codegen::{
model::*,
object::{AnyObject, MinOrMax},
stmt::gen_if_callback,
CodeGenContext, CodeGenerator,
};
use super::NDArrayObject;
impl<'ctx> NDArrayObject<'ctx> {
/// Helper function to implement NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
///
/// Generate LLVM IR to find the extremum and index of the **first** extremum value.
///
/// Care has also been taken to make the error messages match that of NumPy.
fn min_max_argmin_argmax_helper<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
on_empty_err_msg: &str,
) -> (AnyObject<'ctx>, Int<'ctx, SizeT>) {
let sizet_model = IntModel(SizeT);
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
// If the ndarray is empty, throw an error.
let is_empty = self.is_empty(generator, ctx);
ctx.make_assert(
generator,
is_empty.value,
"0:ValueError",
on_empty_err_msg,
[None, None, None],
ctx.current_loc,
);
// Setup and initialize the extremum to be the first element in the ndarray
let pextremum_index = sizet_model.alloca(generator, ctx, "extremum_index");
let pextremum = ctx.builder.build_alloca(dtype_llvm, "extremum").unwrap();
let zero = sizet_model.const_0(generator, ctx.ctx);
pextremum_index.store(ctx, zero);
let first_scalar = self.get_nth_scalar(generator, ctx, zero);
ctx.builder.build_store(pextremum, first_scalar.value).unwrap();
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
let old_extremum = AnyObject { ty: self.dtype, value: old_extremum };
let scalar = nditer.get_scalar(generator, ctx);
let new_extremum = AnyObject::call_min_or_max(ctx, kind, old_extremum, scalar);
gen_if_callback(
generator,
ctx,
|generator, ctx| {
// Is new_extremum is more extreme than old_extremum?
let cmp = AnyObject::compare_int_or_float_by_predicate(
generator,
ctx,
new_extremum,
old_extremum,
IntPredicate::NE,
FloatPredicate::ONE,
"",
);
Ok(cmp.value)
},
|generator, ctx| {
// Yes, update the extremum index
let index = nditer.get_index(generator, ctx);
pextremum_index.store(ctx, index);
Ok(())
},
|_generator, _ctx| {
// No, do nothing
Ok(())
},
)
})
.unwrap();
// Finally return the extremum and extremum index.
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
let extremum = AnyObject { ty: self.dtype, value: extremum };
(extremum, extremum_index)
}
/// Invoke NAC3's builtin `np_min()` or `np_max()`.
pub fn min_or_max<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
) -> AnyObject<'ctx> {
let on_empty_err_msg = format!(
"zero-size array to reduction operation {} which has no identity",
match kind {
MinOrMax::Min => "minimum",
MinOrMax::Max => "maximum",
}
);
self.min_max_argmin_argmax_helper(generator, ctx, kind, &on_empty_err_msg).0
}
/// Invoke NAC3's builtin `np_argmin()` or `np_argmax()`.
pub fn argmin_or_argmax<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
) -> Int<'ctx, SizeT> {
let on_empty_err_msg = format!(
"attempt to get {} of an empty sequence",
match kind {
MinOrMax::Min => "argmin",
MinOrMax::Max => "argmax",
}
);
self.min_max_argmin_argmax_helper(generator, ctx, kind, &on_empty_err_msg).1
}
}

View File

@ -1,321 +0,0 @@
use crate::codegen::{irrt::call_nac3_ndarray_index, model::*, CodeGenContext, CodeGenerator};
use super::NDArrayObject;
pub type NDIndexType = Byte;
/// Fields of [`NDIndex`]
#[derive(Debug, Clone, Copy)]
pub struct NDIndexFields<'ctx, F: FieldTraversal<'ctx>> {
pub type_: F::Out<IntModel<NDIndexType>>, // Defined to be uint8_t in IRRT
pub data: F::Out<PtrModel<IntModel<Byte>>>,
}
/// An IRRT representation fo an ndarray subscript index.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NDIndex;
impl<'ctx> StructKind<'ctx> for NDIndex {
type Fields<F: FieldTraversal<'ctx>> = NDIndexFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields { type_: traversal.add_auto("type"), data: traversal.add_auto("data") }
}
}
/// Fields of [`UserSlice`]
#[derive(Debug, Clone)]
pub struct UserSliceFields<'ctx, F: FieldTraversal<'ctx>> {
pub start_defined: F::Out<IntModel<Bool>>,
pub start: F::Out<IntModel<Int32>>,
pub stop_defined: F::Out<IntModel<Bool>>,
pub stop: F::Out<IntModel<Int32>>,
pub step_defined: F::Out<IntModel<Bool>>,
pub step: F::Out<IntModel<Int32>>,
}
/// An IRRT representation of a user slice.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct UserSlice;
impl<'ctx> StructKind<'ctx> for UserSlice {
type Fields<F: FieldTraversal<'ctx>> = UserSliceFields<'ctx, F>;
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
Self::Fields {
start_defined: traversal.add_auto("start_defined"),
start: traversal.add_auto("start"),
stop_defined: traversal.add_auto("stop_defined"),
stop: traversal.add_auto("stop"),
step_defined: traversal.add_auto("step_defined"),
step: traversal.add_auto("step"),
}
}
}
/// A convenience structure to prepare a [`UserSlice`].
#[derive(Debug, Clone)]
pub struct RustUserSlice<'ctx> {
pub start: Option<Int<'ctx, Int32>>,
pub stop: Option<Int<'ctx, Int32>>,
pub step: Option<Int<'ctx, Int32>>,
}
impl<'ctx> RustUserSlice<'ctx> {
/// Write the contents to an LLVM [`UserSlice`].
pub fn write_to_user_slice<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
dst_slice_ptr: Ptr<'ctx, StructModel<UserSlice>>,
) {
let bool_model = IntModel(Bool);
let false_ = bool_model.constant(generator, ctx.ctx, 0);
let true_ = bool_model.constant(generator, ctx.ctx, 1);
// TODO: Code duplication. Probably okay...?
match self.start {
Some(start) => {
dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.start).store(ctx, start);
}
None => dst_slice_ptr.gep(ctx, |f| f.start_defined).store(ctx, false_),
}
match self.stop {
Some(stop) => {
dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.stop).store(ctx, stop);
}
None => dst_slice_ptr.gep(ctx, |f| f.stop_defined).store(ctx, false_),
}
match self.step {
Some(step) => {
dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, true_);
dst_slice_ptr.gep(ctx, |f| f.step).store(ctx, step);
}
None => dst_slice_ptr.gep(ctx, |f| f.step_defined).store(ctx, false_),
}
}
}
// A convenience enum variant to store the content and type of an NDIndex in high level.
#[derive(Debug, Clone)]
pub enum RustNDIndex<'ctx> {
SingleElement(Int<'ctx, Int32>), // TODO: To be SizeT
Slice(RustUserSlice<'ctx>),
NewAxis,
Ellipsis,
}
impl<'ctx> RustNDIndex<'ctx> {
/// Get the value to set `NDIndex::type` for this variant.
fn get_type_id(&self) -> u64 {
// Defined in IRRT, must be in sync
match self {
RustNDIndex::SingleElement(_) => 0,
RustNDIndex::Slice(_) => 1,
RustNDIndex::NewAxis => 2,
RustNDIndex::Ellipsis => 3,
}
}
/// Write the contents to an LLVM [`NDIndex`].
fn write_to_ndindex<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
dst_ndindex_ptr: Ptr<'ctx, StructModel<NDIndex>>,
) {
let ndindex_type_model = IntModel(NDIndexType::default());
let i32_model = IntModel(Int32);
let user_slice_model = StructModel(UserSlice);
// Set `dst_ndindex_ptr->type`
dst_ndindex_ptr
.gep(ctx, |f| f.type_)
.store(ctx, ndindex_type_model.constant(generator, ctx.ctx, self.get_type_id()));
// Set `dst_ndindex_ptr->data`
match self {
RustNDIndex::SingleElement(in_index) => {
let index_ptr = i32_model.alloca(generator, ctx, "index");
index_ptr.store(ctx, *in_index);
dst_ndindex_ptr
.gep(ctx, |f| f.data)
.store(ctx, index_ptr.pointer_cast(generator, ctx, IntModel(Byte), ""));
}
RustNDIndex::Slice(in_rust_slice) => {
let user_slice_ptr = user_slice_model.alloca(generator, ctx, "user_slice");
in_rust_slice.write_to_user_slice(generator, ctx, user_slice_ptr);
dst_ndindex_ptr
.gep(ctx, |f| f.data)
.store(ctx, user_slice_ptr.pointer_cast(generator, ctx, IntModel(Byte), ""));
}
RustNDIndex::NewAxis | RustNDIndex::Ellipsis => {}
}
}
/// Allocate an array of `NDIndex`es on the stack and return its stack pointer.
pub fn alloca_ndindexes<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
in_ndindexes: &[RustNDIndex<'ctx>],
) -> (Int<'ctx, SizeT>, Ptr<'ctx, StructModel<NDIndex>>) {
let sizet_model = IntModel(SizeT);
let ndindex_model = StructModel(NDIndex);
let num_ndindexes = sizet_model.constant(generator, ctx.ctx, in_ndindexes.len() as u64);
let ndindexes =
ndindex_model.array_alloca(generator, ctx, num_ndindexes.value, "ndindexes");
for (i, in_ndindex) in in_ndindexes.iter().enumerate() {
let i = sizet_model.constant(generator, ctx.ctx, i as u64);
let pndindex = ndindexes.offset(generator, ctx, i.value, "");
in_ndindex.write_to_ndindex(generator, ctx, pndindex);
}
(num_ndindexes, ndindexes)
}
}
impl<'ctx> NDArrayObject<'ctx> {
/// Get the ndims [`Type`] after indexing with a given slice.
#[must_use]
pub fn deduce_ndims_after_indexing_with(&self, indexes: &[RustNDIndex<'ctx>]) -> u64 {
let mut ndims = self.ndims;
for index in indexes {
match index {
RustNDIndex::SingleElement(_) => {
ndims -= 1; // Single elements decrements ndims
}
RustNDIndex::NewAxis => {
ndims += 1; // `np.newaxis` / `none` adds a new axis
}
RustNDIndex::Ellipsis | RustNDIndex::Slice(_) => {}
}
}
ndims
}
/// Index into the ndarray, and return a newly-allocated view on this ndarray.
///
/// This function behaves like NumPy's ndarray indexing, but if the indexes index
/// into a single element, an unsized ndarray is returned.
#[must_use]
pub fn index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
indexes: &[RustNDIndex<'ctx>],
name: &str,
) -> Self {
let dst_ndims = self.deduce_ndims_after_indexing_with(indexes);
let dst_ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, dst_ndims, name);
let (num_indexes, indexes) = RustNDIndex::alloca_ndindexes(generator, ctx, indexes);
call_nac3_ndarray_index(
generator,
ctx,
num_indexes,
indexes,
self.instance,
dst_ndarray.instance,
);
dst_ndarray
}
}
pub mod util {
use itertools::Itertools;
use nac3parser::ast::{Constant, Expr, ExprKind};
use crate::{
codegen::{expr::gen_slice, model::*, CodeGenContext, CodeGenerator},
typecheck::typedef::{Type, TypeEnum},
};
use super::{RustNDIndex, RustUserSlice};
/// Generate LLVM code to transform an ndarray subscript expression to
/// its list of [`RustNDIndex`]
///
/// i.e.,
/// ```python
/// my_ndarray[::3, 1, :2:]
/// ^^^^^^^^^^^ Then these into a three `RustNDIndex`es
/// ```
pub fn gen_ndarray_subscript_ndindexes<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
subscript: &Expr<Option<Type>>,
) -> Result<Vec<RustNDIndex<'ctx>>, String> {
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools
let i32_model = IntModel(Int32);
// Annoying notes about `slice`
// - `my_array[5]`
// - slice is a `Constant`
// - `my_array[:5]`
// - slice is a `Slice`
// - `my_array[:]`
// - slice is a `Slice`, but lower upper step would all be `Option::None`
// - `my_array[:, :]`
// - slice is now a `Tuple` of two `Slice`-s
//
// In summary:
// - when there is a comma "," within [], `slice` will be a `Tuple` of the entries.
// - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself.
//
// So we first "flatten" out the slice expression
let index_exprs = match &subscript.node {
ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(),
_ => vec![subscript],
};
// Process all index expressions
let mut rust_ndindexes: Vec<RustNDIndex> = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here.
for index_expr in index_exprs {
// NOTE: Currently nac3core's slices do not have an object representation,
// so the code/implementation looks awkward - we have to do pattern matching on the expression
let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node {
// Handle slices
// Helper function here to deduce code duplication
let (lower, upper, step) = gen_slice(generator, ctx, lower, upper, step)?;
RustNDIndex::Slice(RustUserSlice { start: lower, stop: upper, step })
} else if let ExprKind::Constant { value: Constant::Ellipsis, .. } = &index_expr.node {
// Handle '...'
RustNDIndex::Ellipsis
} else {
match &*ctx.unifier.get_ty(index_expr.custom.unwrap()) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
{
// Handle `np.newaxis` / `None`
RustNDIndex::NewAxis
}
_ => {
// Treat and handle everything else as a single element index.
let index =
generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
ctx,
generator,
ctx.primitives.int32, // Must be int32, this checks for illegal values
)?;
let index = i32_model.check_value(generator, ctx.ctx, index).unwrap();
RustNDIndex::SingleElement(index)
}
}
};
rust_ndindexes.push(ndindex);
}
Ok(rust_ndindexes)
}
}

View File

@ -1,200 +0,0 @@
use itertools::Itertools;
use crate::{
codegen::{
object::ndarray::{AnyObject, NDArrayObject},
stmt::gen_for_callback,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type,
};
use super::{nditer::NDIterHandle, scalar::ScalarOrNDArray, NDArrayOut};
impl<'ctx> NDArrayObject<'ctx> {
/// TODO: Document me. Has complex behavior.
/// and explain why `ret_dtype` has to be specified beforehand.
pub fn broadcasting_starmap<'a, G, MappingFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ndarrays: &[Self],
out: NDArrayOut<'ctx>,
mapping: MappingFn,
) -> Result<Self, String>
where
G: CodeGenerator + ?Sized,
MappingFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
&[AnyObject<'ctx>],
) -> Result<AnyObject<'ctx>, String>,
{
// Broadcast inputs
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
let out_ndarray = match out {
NDArrayOut::NewNDArray { dtype } => {
// Create a new ndarray based on the broadcast shape.
let result_ndarray = NDArrayObject::alloca(
generator,
ctx,
dtype,
broadcast_result.ndims,
"mapped_ndarray",
);
result_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
result_ndarray.create_data(generator, ctx);
result_ndarray
}
NDArrayOut::WriteToNDArray { ndarray: result_ndarray } => {
// Use an existing ndarray.
// Check that its shape is compatible with the broadcast shape.
result_ndarray.check_can_be_written_by_out(
generator,
ctx,
broadcast_result.ndims,
broadcast_result.shape,
);
result_ndarray
}
};
// Map element-wise and store results into `mapped_ndarray`.
let nditer = NDIterHandle::new(generator, ctx, out_ndarray);
gen_for_callback(
generator,
ctx,
Some("broadcast_starmap"),
|generator, ctx| {
// Create NDIters for all broadcasted input ndarrays.
let other_nditers = broadcast_result
.ndarrays
.iter()
.map(|ndarray| NDIterHandle::new(generator, ctx, *ndarray))
.collect_vec();
Ok((nditer, other_nditers))
},
|generator, ctx, (out_nditer, _in_nditers)| {
// We can simply use `out_nditer`'s `has_next()`.
// `in_nditers`' `has_next()`s should return the same value.
Ok(out_nditer.has_next(generator, ctx).value)
},
|generator, ctx, _hooks, (out_nditer, in_nditers)| {
// Get all the scalars from the broadcasted input ndarrays, pass them to `mapping`,
// and write to `out_ndarray`.
let in_scalars =
in_nditers.iter().map(|nditer| nditer.get_scalar(generator, ctx)).collect_vec();
let result = mapping(generator, ctx, &in_scalars)?;
// Sanity check on result's ty
assert!(ctx.unifier.unioned(result.ty, out_ndarray.dtype));
let p = out_nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, result.value).unwrap();
Ok(())
},
|generator, ctx, (out_nditer, in_nditers)| {
// Advance all iterators
out_nditer.next(generator, ctx);
in_nditers.iter().for_each(|nditer| nditer.next(generator, ctx));
Ok(())
},
)?;
Ok(out_ndarray)
}
pub fn map<'a, G, Mapping>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
out: NDArrayOut<'ctx>,
mapping: Mapping,
) -> Result<Self, String>
where
G: CodeGenerator + ?Sized,
Mapping: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
AnyObject<'ctx>,
) -> Result<AnyObject<'ctx>, String>,
{
NDArrayObject::broadcasting_starmap(
generator,
ctx,
&[*self],
out,
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
)
}
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// TODO: Document me. Has complex behavior.
/// and explain why `ret_dtype` has to be specified beforehand.
pub fn broadcasting_starmap<'a, G, MappingFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
inputs: &[Self],
ret_dtype: Type,
mapping: MappingFn,
) -> Result<Self, String>
where
G: CodeGenerator + ?Sized,
MappingFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
&[AnyObject<'ctx>],
) -> Result<AnyObject<'ctx>, String>,
{
// Check if all inputs are AnyObjects
let all_scalars: Option<Vec<_>> = inputs.iter().map(AnyObject::try_from).try_collect().ok();
if let Some(scalars) = all_scalars {
let scalar = mapping(generator, ctx, &scalars)?;
// Sanity check on scalar's type
assert!(ctx.unifier.unioned(scalar.ty, ret_dtype));
Ok(ScalarOrNDArray::Scalar(scalar))
} else {
// Promote all input to ndarrays and map through them.
let inputs = inputs.iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec();
let ndarray = NDArrayObject::broadcasting_starmap(
generator,
ctx,
&inputs,
NDArrayOut::NewNDArray { dtype: ret_dtype },
mapping,
)?;
Ok(ScalarOrNDArray::NDArray(ndarray))
}
}
pub fn map<'a, G, Mapping>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ret_dtype: Type,
mapping: Mapping,
) -> Result<Self, String>
where
G: CodeGenerator + ?Sized,
Mapping: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
AnyObject<'ctx>,
) -> Result<AnyObject<'ctx>, String>,
{
ScalarOrNDArray::broadcasting_starmap(
generator,
ctx,
&[*self],
ret_dtype,
|generator, ctx, scalars| mapping(generator, ctx, scalars[0]),
)
}
}

View File

@ -1,831 +0,0 @@
pub mod array;
pub mod broadcast;
pub mod factory;
pub mod functions;
pub mod indexing;
pub mod mapping;
pub mod nalgebra;
pub mod nditer;
pub mod product;
pub mod scalar;
pub mod shape_util;
use crate::{
codegen::{
irrt::{
call_nac3_ndarray_copy_data, call_nac3_ndarray_get_nth_pelement,
call_nac3_ndarray_is_c_contiguous, call_nac3_ndarray_len, call_nac3_ndarray_nbytes,
call_nac3_ndarray_resolve_and_check_new_shape, call_nac3_ndarray_set_strides_by_shape,
call_nac3_ndarray_size, call_nac3_ndarray_transpose,
call_nac3_ndarray_util_assert_output_shape_same,
},
model::*,
stmt::{gen_for_callback, BreakContinueHooks},
structure::{NDArray, SimpleNDArray},
CodeGenContext, CodeGenerator,
},
toplevel::{
helper::{create_ndims, extract_ndims},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
},
typecheck::typedef::Type,
};
use indexing::RustNDIndex;
use inkwell::{
context::Context,
types::BasicType,
values::{BasicValue, PointerValue},
AddressSpace, IntPredicate,
};
use nditer::NDIterHandle;
use scalar::ScalarOrNDArray;
use util::call_memcpy_model;
use super::{tuple::TupleObject, AnyObject};
/// A NAC3 Python ndarray object.
#[derive(Debug, Clone, Copy)]
pub struct NDArrayObject<'ctx> {
pub dtype: Type,
pub ndims: u64,
pub instance: Ptr<'ctx, StructModel<NDArray>>,
}
impl<'ctx> NDArrayObject<'ctx> {
/// Create an [`NDArrayObject`] from an LLVM value and its typechecker [`Type`].
pub fn from_object<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
) -> Self {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
Self::from_value_and_unpacked_types(generator, ctx, object.value, dtype, ndims)
}
/// Like [`NDArrayObject::from_object`] but you directly supply the ndarray's
/// `dtype` and `ndims`.
pub fn from_value_and_unpacked_types<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
value: V,
dtype: Type,
ndims: u64,
) -> Self {
let pndarray_model = PtrModel(StructModel(NDArray));
let value = pndarray_model.check_value(generator, ctx.ctx, value).unwrap();
NDArrayObject { dtype, ndims, instance: value }
}
/// Forget that this is an ndarray and convert to an [`AnyObject`].
pub fn to_any_object(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
let ty = self.get_ndarray_type(ctx);
AnyObject { value: self.instance.value.as_basic_value_enum(), ty }
}
/// Create a [`SimpleNDArray`] from the contents of this ndarray.
///
/// This function may or may not be expensive depending on if this ndarray has contiguous data.
///
/// If this ndarray is not C-contiguous, this function will allocate memory on the stack for the `data` field of
/// the returned [`SimpleNDArray`] and copy contents of this ndarray to there.
///
/// If this ndarray is C-contiguous, contents of this ndarray will not be copied. The created [`SimpleNDArray`]
/// will have the same `data` field as this ndarray.
///
/// The `item_model` sets the [`Model`] of the returned [`SimpleNDArray`]'s `Item` model, and should match the
/// `ctx.get_llvm_type()` of this ndarray's `dtype`. Otherwise this function panics.
pub fn make_simple_ndarray<G: CodeGenerator + ?Sized, Item: Model<'ctx>>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
item_model: Item,
name: &str,
) -> Ptr<'ctx, StructModel<SimpleNDArray<Item>>> {
// Sanity check on `self.dtype` and `item_model`.
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
item_model.check_type(generator, ctx.ctx, dtype_llvm).unwrap();
let simple_ndarray_model = StructModel(SimpleNDArray { item: item_model });
let current_bb = ctx.builder.get_insert_block().unwrap();
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
// Allocate and setup the resulting [`SimpleNDArray`].
let result = simple_ndarray_model.alloca(generator, ctx, name);
// Set ndims and shape.
let ndims = self.get_ndims(generator, ctx.ctx);
result.set(ctx, |f| f.ndims, ndims);
let shape = self.instance.get(generator, ctx, |f| f.shape, "shape");
result.set(ctx, |f| f.shape, shape);
// Set data, but we do things differently if this ndarray is contiguous.
let is_contiguous = self.is_c_contiguous(generator, ctx);
ctx.builder.build_conditional_branch(is_contiguous.value, then_bb, else_bb).unwrap();
// Inserting into then_bb; This ndarray is contiguous.
let data = self.instance.get(generator, ctx, |f| f.data, "");
let data = data.pointer_cast(generator, ctx, item_model, "");
result.set(ctx, |f| f.data, data);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Inserting into else_bb; This ndarray is not contiguous. Do a full-copy on `data`.
// TODO: Reimplement this? This method does give us the contiguous `data`, but
// this creates a few extra bytes of useless information because an entire NDArray
// is allocated. Though this is super convenient.
let data = self.make_copy(generator, ctx, "").instance.get(generator, ctx, |f| f.data, "");
let data = data.pointer_cast(generator, ctx, item_model, "");
result.set(ctx, |f| f.data, data);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Reposition to end_bb for continuation
ctx.builder.position_at_end(end_bb);
result
}
/// Create an [`NDArrayObject`] from a [`SimpleNDArray`].
///
/// This operation is super cheap. The newly created [`NDArray`] will not copy contents
/// from `simple_ndarray`, but only having its `data` and `shape` pointing to `simple_array`.
pub fn from_simple_ndarray<G: CodeGenerator + ?Sized, Item: Model<'ctx>>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
simple_ndarray: Ptr<'ctx, StructModel<SimpleNDArray<Item>>>,
dtype: Type,
ndims: u64,
) -> Self {
// Sanity check on `dtype` and `simple_array`'s `Item` model.
let dtype_llvm = ctx.get_llvm_type(generator, dtype);
simple_ndarray.model.0 .0.item.check_type(generator, ctx.ctx, dtype_llvm).unwrap();
let byte_model = IntModel(Byte);
// TODO: Check if `ndims` is consistent with that in `simple_array`?
// Allocate the resulting ndarray.
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims, "from_simple_ndarray");
// Set data, shape by simply copying addresses.
let data = simple_ndarray
.get(generator, ctx, |f| f.data, "")
.pointer_cast(generator, ctx, byte_model, "data");
ndarray.instance.set(ctx, |f| f.data, data);
let shape = simple_ndarray.get(generator, ctx, |f| f.shape, "shape");
ndarray.instance.set(ctx, |f| f.shape, shape);
// Set strides. We know `data` is contiguous.
ndarray.update_strides_by_shape(generator, ctx);
ndarray
}
/// Get the typechecker ndarray type of this [`NDArrayObject`].
pub fn get_ndarray_type(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Type {
let ndims = create_ndims(&mut ctx.unifier, self.ndims);
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(self.dtype), Some(ndims))
}
/// Get the `np.size()` of this ndarray.
pub fn size<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_size(generator, ctx, self.instance)
}
/// Get the `ndarray.nbytes` of this ndarray.
pub fn nbytes<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_nbytes(generator, ctx, self.instance)
}
/// Get the `len()` of this ndarray.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
call_nac3_ndarray_len(generator, ctx, self.instance)
}
/// Check if this ndarray is C-contiguous.
///
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> {
call_nac3_ndarray_is_c_contiguous(generator, ctx, self.instance)
}
/// Get the pointer to the n-th (0-based) element.
///
/// The returned pointer has the element type of the LLVM type of this ndarray's `dtype`.
///
/// There is no out-of-bounds check.
pub fn get_nth_pointer<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Int<'ctx, SizeT>,
name: &str,
) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.dtype);
let p = call_nac3_ndarray_get_nth_pelement(generator, ctx, self.instance, nth);
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), name)
.unwrap()
}
/// Get the n-th (0-based) scalar.
///
/// There is no out-of-bounds check.
pub fn get_nth_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Int<'ctx, SizeT>,
) -> AnyObject<'ctx> {
let p = self.get_nth_pointer(generator, ctx, nth, "value");
let value = ctx.builder.build_load(p, "value").unwrap();
AnyObject { ty: self.dtype, value }
}
/// Set the n-th (0-based) scalar.
///
/// There is no out-of-bounds check.
pub fn set_nth_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Int<'ctx, SizeT>,
scalar: AnyObject<'ctx>,
) {
// Sanity check on scalar's `dtype`
assert!(ctx.unifier.unioned(scalar.ty, self.dtype));
let pscalar = self.get_nth_pointer(generator, ctx, nth, "pscalar");
ctx.builder.build_store(pscalar, scalar.value).unwrap();
}
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
///
/// Please refer to the IRRT implementation to see its purpose.
pub fn update_strides_by_shape<G: CodeGenerator + ?Sized>(
self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
call_nac3_ndarray_set_strides_by_shape(generator, ctx, self.instance);
}
/// Copy data from another ndarray.
///
/// This ndarray and `src` is that their `np.size()` should be the same. Their shapes
/// do not matter. The copying order is determined by how their flattened views look.
///
/// Panics if the `dtype`s of ndarrays are different.
pub fn copy_data_from<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src: NDArrayObject<'ctx>,
) {
assert!(ctx.unifier.unioned(self.dtype, src.dtype), "self and src dtype should match");
call_nac3_ndarray_copy_data(generator, ctx, src.instance, self.instance);
}
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
///
/// `shape` and `strides` will be automatically allocated on the stack.
//e
/// The returned ndarray's content will be:
/// - `data`: set to `nullptr`.
/// - `itemsize`: set to the `sizeof()` of `dtype`.
/// - `ndims`: set to the value of `ndims`.
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
pub fn alloca<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
ndims: u64,
name: &str,
) -> Self {
let sizet_model = IntModel(SizeT);
let ndarray_model = StructModel(NDArray);
let ndarray_data_model = PtrModel(IntModel(Byte));
let pndarray = ndarray_model.alloca(generator, ctx, name);
let data = ndarray_data_model.nullptr(generator, ctx.ctx);
pndarray.set(ctx, |f| f.data, data);
let itemsize = ctx.get_llvm_type(generator, dtype).size_of().unwrap();
let itemsize =
sizet_model.s_extend_or_bit_cast(generator, ctx, itemsize, "alloca_itemsize");
pndarray.set(ctx, |f| f.itemsize, itemsize);
let ndims_val = sizet_model.constant(generator, ctx.ctx, ndims);
pndarray.set(ctx, |f| f.ndims, ndims_val);
let shape = sizet_model.array_alloca(generator, ctx, ndims_val.value, "alloca_shape");
pndarray.set(ctx, |f| f.shape, shape);
let strides = sizet_model.array_alloca(generator, ctx, ndims_val.value, "alloca_strides");
pndarray.set(ctx, |f| f.strides, strides);
NDArrayObject { dtype, ndims, instance: pndarray }
}
/// Convenience function.
/// Like [`NDArrayObject::alloca_uninitialized`] but directly takes the typechecker type of the ndarray.
pub fn alloca_ndarray_type<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray_ty: Type,
name: &str,
) -> Self {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
Self::alloca(generator, ctx, dtype, ndims, name)
}
/// Convenience function. Allocate an [`NDArrayObject`] with a statically known shape.
///
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
pub fn alloca_constant_shape<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
shape: &[u64],
name: &str,
) -> Self {
let sizet_model = IntModel(SizeT);
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64, name);
// Write shape
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape");
for (i, dim) in shape.iter().enumerate() {
let dim = sizet_model.constant(generator, ctx.ctx, *dim);
dst_shape.offset_const(generator, ctx, i as u64, "").store(ctx, dim);
}
ndarray
}
/// Convenience function. Allocate an [`NDArrayObject`] with a dynamically known shape.
///
/// The returned [`NDArrayObject`]'s `data` and `strides` are uninitialized.
pub fn alloca_dynamic_shape<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dtype: Type,
shape: &[Int<'ctx, SizeT>],
name: &str,
) -> Self {
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, shape.len() as u64, name);
// Write shape
let dst_shape = ndarray.instance.get(generator, ctx, |f| f.shape, "shape");
for (i, dim) in shape.iter().enumerate() {
dst_shape.offset_const(generator, ctx, i as u64, "").store(ctx, *dim);
}
ndarray
}
/// Clone/Copy this ndarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents over.
///
/// The new ndarray will own its data and will be C-contiguous.
#[must_use]
pub fn make_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: &str,
) -> Self {
let clone = NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims, name);
let shape = self.instance.gep(ctx, |f| f.shape).load(generator, ctx, "shape");
clone.copy_shape_from_array(generator, ctx, shape);
clone.create_data(generator, ctx);
clone.copy_data_from(generator, ctx, *self);
clone
}
/// Get this ndarray's `ndims` as an LLVM constant.
pub fn get_ndims<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &'ctx Context,
) -> Int<'ctx, SizeT> {
let sizet_model = IntModel(SizeT);
sizet_model.constant(generator, ctx, self.ndims)
}
/// Get if this ndarray's `np.size` is `0` - containing no content.
pub fn is_empty<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> {
let sizet_model = IntModel(SizeT);
let size = self.size(generator, ctx);
size.compare(ctx, IntPredicate::EQ, sizet_model.const_0(generator, ctx.ctx), "is_empty")
}
/// Return true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
///
/// This is a staticially known property of ndarrays. This is why it is returning
/// a Rust boolean instead of a [`BasicValue`].
#[must_use]
pub fn is_unsized(&self) -> bool {
self.ndims == 0
}
/// If this ndarray is unsized, return its sole value as a [`AnyObject`]. Otherwise, do nothing.
pub fn split_unsized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ScalarOrNDArray<'ctx> {
if self.is_unsized() {
// NOTE: `np.size(self) == 0` here is never possible.
let sizet_model = IntModel(SizeT);
let zero = sizet_model.const_0(generator, ctx.ctx);
ScalarOrNDArray::Scalar(self.get_nth_scalar(generator, ctx, zero))
} else {
ScalarOrNDArray::NDArray(*self)
}
}
/// Initialize an ndarray's `data` by allocating a buffer on the stack.
/// The allocated data buffer is considered to be *owned* by the ndarray.
///
/// `strides` of the ndarray will also be updated with `set_strides_by_shape`.
///
/// `shape` and `itemsize` of the ndarray ***must*** be initialized first.
pub fn create_data<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
let byte_model = IntModel(Byte);
let nbytes = self.nbytes(generator, ctx);
let data = byte_model.array_alloca(generator, ctx, nbytes.value, "data");
self.instance.set(ctx, |f| f.data, data);
self.update_strides_by_shape(generator, ctx);
}
/// Copy shape dimensions from an array.
pub fn copy_shape_from_array<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let dst_shape = self.instance.get(generator, ctx, |f| f.shape, "dst_shape");
let num_items = self.get_ndims(generator, ctx.ctx).value;
call_memcpy_model(generator, ctx, dst_shape, src_shape, num_items);
}
/// Copy shape dimensions from an ndarray.
/// Panics if `ndims` mismatches.
pub fn copy_shape_from_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayObject<'ctx>,
) {
assert_eq!(self.ndims, src_ndarray.ndims);
let src_shape = src_ndarray.instance.get(generator, ctx, |f| f.shape, "src_shape");
self.copy_shape_from_array(generator, ctx, src_shape);
}
/// Copy strides dimensions from an array.
pub fn copy_strides_from_array<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_strides: Ptr<'ctx, IntModel<SizeT>>,
) {
let dst_strides = self.instance.get(generator, ctx, |f| f.strides, "dst_strides");
let num_items = self.get_ndims(generator, ctx.ctx).value;
call_memcpy_model(generator, ctx, dst_strides, src_strides, num_items);
}
/// Copy strides dimensions from an ndarray.
/// Panics if `ndims` mismatches.
pub fn copy_strides_from_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayObject<'ctx>,
) {
assert_eq!(self.ndims, src_ndarray.ndims);
let src_strides = src_ndarray.instance.get(generator, ctx, |f| f.strides, "src_strides");
self.copy_strides_from_array(generator, ctx, src_strides);
}
/// Iterate through every element in the ndarray.
///
/// `body` also access to [`BreakContinueHooks`] to short-circuit.
pub fn foreach<'a, G, F>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
NDIterHandle<'ctx>,
) -> Result<(), String>,
{
gen_for_callback(
generator,
ctx,
Some("ndarray_foreach"),
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value),
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|generator, ctx, nditer| {
nditer.next(generator, ctx);
Ok(())
},
)
}
/// Make sure the ndarray is at least `ndmin`-dimensional.
///
/// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended to the shape.
/// If this ndarray's `ndims` is not less than `ndmin`, this function does nothing and return this ndarray.
#[must_use]
pub fn atleast_nd<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndmin: u64,
) -> Self {
if self.ndims < ndmin {
let mut indices = vec![];
for _ in self.ndims..ndmin {
indices.push(RustNDIndex::NewAxis);
}
indices.push(RustNDIndex::Ellipsis);
self.index(generator, ctx, &indices, "atleast_nd_ndarray")
} else {
*self
}
}
/// Fill the ndarray with a scalar.
///
/// `fill_value` must have the same LLVM type as the `dtype` of this ndarray.
pub fn fill<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
scalar: AnyObject<'ctx>,
) {
// Sanity check on scalar's type.
assert!(ctx.unifier.unioned(self.dtype, scalar.ty));
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let p = nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, scalar.value).unwrap();
Ok(())
})
.unwrap();
}
/// Create a reshaped view on this ndarray like `np.reshape()`.
///
/// If there is a `-1` in `new_shape`, it will be resolved; `new_shape` would **NOT** be modified as a result.
///
/// If reshape without copying is impossible, this function will allocate a new ndarray and copy contents.
///
/// * `new_ndims` - The number of dimensions of `new_shape` as a [`Type`].
/// * `new_shape` - The target shape to do `np.reshape()`.
#[must_use]
pub fn reshape_or_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
new_ndims: u64,
new_shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
// TODO: The current criterion for whether to do a full copy or not is by checking `is_c_contiguous`,
// but this is not optimal - there are cases when the ndarray is not contiguous but could be reshaped
// without copying data. Look into how numpy does it.
let current_bb = ctx.builder.get_insert_block().unwrap();
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
let dst_ndarray =
NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims, "reshaped_ndarray");
dst_ndarray.copy_shape_from_array(generator, ctx, new_shape);
// Reolsve negative indices
let size = self.size(generator, ctx);
let dst_ndims = dst_ndarray.get_ndims(generator, ctx.ctx);
let dst_shape =
dst_ndarray.instance.get(generator, ctx, |f| f.shape, "reshaped_ndarray_shape");
call_nac3_ndarray_resolve_and_check_new_shape(generator, ctx, size, dst_ndims, dst_shape);
let is_c_contiguous = self.is_c_contiguous(generator, ctx);
ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap();
// Inserting into then_bb: reshape is possible without copying
ctx.builder.position_at_end(then_bb);
dst_ndarray.update_strides_by_shape(generator, ctx);
dst_ndarray.instance.set(
ctx,
|f| f.data,
self.instance.get(generator, ctx, |f| f.data, "data"),
);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Inserting into else_bb: reshape is impossible without copying
ctx.builder.position_at_end(else_bb);
dst_ndarray.create_data(generator, ctx);
dst_ndarray.copy_data_from(generator, ctx, *self);
ctx.builder.build_unconditional_branch(end_bb).unwrap();
// Reposition for continuation
ctx.builder.position_at_end(end_bb);
dst_ndarray
}
/// Create a flattened view of this ndarray, like `np.ravel()`.
///
/// Uses [`NDArrayObject::reshape_or_copy`] under-the-hood so ndarray may or may not be copied.
#[must_use]
pub fn ravel_or_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
// Define models
let sizet_model = IntModel(SizeT);
let num0 = sizet_model.const_0(generator, ctx.ctx);
let num1 = sizet_model.const_1(generator, ctx.ctx);
let num_neg1 = sizet_model.const_all_1s(generator, ctx.ctx);
// Create `[-1]` and pass to `reshape_or_copy`.
let new_shape = sizet_model.array_alloca(generator, ctx, num1.value, "new_shape");
new_shape.offset(generator, ctx, num0.value, "").store(ctx, num_neg1);
self.reshape_or_copy(generator, ctx, 1, new_shape)
}
/// Create a transposed view on this ndarray like `np.transpose(<ndarray>, <axes> = None)`.
/// * `axes` - If specified, should be an array of the permutation (negative indices are **allowed**).
#[must_use]
pub fn transpose<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
axes: Option<Ptr<'ctx, IntModel<SizeT>>>,
) -> Self {
// Define models
let sizet_model = IntModel(SizeT);
let transposed_ndarray =
NDArrayObject::alloca(generator, ctx, self.dtype, self.ndims, "transposed_ndarray");
let num_axes = self.get_ndims(generator, ctx.ctx);
// `axes = nullptr` if `axes` is unspecified.
let axes = axes.unwrap_or_else(|| PtrModel(sizet_model).nullptr(generator, ctx.ctx));
call_nac3_ndarray_transpose(
generator,
ctx,
self.instance,
transposed_ndarray.instance,
num_axes,
axes,
);
transposed_ndarray
}
/// Check if this `NDArray` can be used as an `out` ndarray for an operation.
///
/// Raise an exception if the shapes do not match.
pub fn check_can_be_written_by_out<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
out_ndims: u64,
out_shape: Ptr<'ctx, IntModel<SizeT>>,
) {
let sizet_model = IntModel(SizeT);
let ndarray_ndims = self.get_ndims(generator, ctx.ctx);
let ndarray_shape = self.instance.get(generator, ctx, |f| f.shape, "shape");
let output_ndims = sizet_model.constant(generator, ctx.ctx, out_ndims);
let output_shape = out_shape;
call_nac3_ndarray_util_assert_output_shape_same(
generator,
ctx,
ndarray_ndims,
ndarray_shape,
output_ndims,
output_shape,
);
}
/// Create the shape tuple of this ndarray like `np.shape(<ndarray>)`.
///
/// The returned integers in the tuple are in int32.
pub fn make_shape_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Don't return a tuple of int32s.
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.shape, "")
.offset_const(generator, ctx, i, "")
.load(generator, ctx, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::create(generator, ctx, objects, "shape")
}
/// Create the strides tuple of this ndarray like `np.strides(<ndarray>)`.
///
/// The returned integers in the tuple are in int32.
pub fn make_strides_tuple<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> TupleObject<'ctx> {
// TODO: Don't return a tuple of int32s.
let mut objects = Vec::with_capacity(self.ndims as usize);
for i in 0..self.ndims {
let dim = self
.instance
.get(generator, ctx, |f| f.strides, "")
.offset_const(generator, ctx, i, "")
.load(generator, ctx, "dim");
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
objects.push(AnyObject {
ty: ctx.primitives.int32,
value: dim.value.as_basic_value_enum(),
});
}
TupleObject::create(generator, ctx, objects, "strides")
}
}
/// TODO: Document me
#[derive(Debug, Clone, Copy)]
pub enum NDArrayOut<'ctx> {
NewNDArray { dtype: Type },
WriteToNDArray { ndarray: NDArrayObject<'ctx> },
}

View File

@ -1,53 +0,0 @@
use inkwell::values::{BasicValue, BasicValueEnum};
use crate::codegen::{model::*, structure::SimpleNDArray, CodeGenContext, CodeGenerator};
use super::NDArrayObject;
pub fn perform_nalgebra_call<'ctx, 'a, const NUM_INPUTS: usize, const NUM_OUTPUTS: usize, G, F>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
inputs: [NDArrayObject<'ctx>; NUM_INPUTS],
output_ndims: [u64; NUM_OUTPUTS],
invoke_function: F,
) -> [NDArrayObject<'ctx>; NUM_OUTPUTS]
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut CodeGenContext<'ctx, 'a>,
[BasicValueEnum<'ctx>; NUM_INPUTS],
[BasicValueEnum<'ctx>; NUM_OUTPUTS],
),
{
// TODO: Allow stacked inputs. See NumPy docs.
let f64_model = FloatModel(Float64);
let simple_ndarray_model = StructModel(SimpleNDArray { item: f64_model });
// Prepare inputs & outputs and invoke
let inputs = inputs.map(|input| {
// Sanity check. Typechecker ensures this.
assert!(ctx.unifier.unioned(input.dtype, ctx.primitives.float));
input
.make_simple_ndarray(generator, ctx, FloatModel(Float64), "nalgebra_input")
.value
.as_basic_value_enum()
});
let outputs = [simple_ndarray_model.alloca(generator, ctx, "nalgebra_output"); NUM_OUTPUTS];
invoke_function(ctx, inputs, outputs.map(|output| output.value.as_basic_value_enum()));
// Turn the outputs into strided NDArrays
let mut output_i = 0;
outputs.map(|output| {
let out = NDArrayObject::from_simple_ndarray(
generator,
ctx,
output,
ctx.primitives.float,
output_ndims[output_i],
);
output_i += 1;
out
})
}

View File

@ -1,88 +0,0 @@
use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
use crate::codegen::{
irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next},
model::*,
object::AnyObject,
structure::NDIter,
CodeGenContext, CodeGenerator,
};
use super::NDArrayObject;
#[derive(Debug, Clone)]
pub struct NDIterHandle<'ctx> {
ndarray: NDArrayObject<'ctx>,
instance: Ptr<'ctx, StructModel<NDIter>>,
indices: Ptr<'ctx, IntModel<SizeT>>,
}
impl<'ctx> NDIterHandle<'ctx> {
pub fn new<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayObject<'ctx>,
) -> Self {
let nditer = StructModel(NDIter).alloca(generator, ctx, "nditer");
let ndims = ndarray.get_ndims(generator, ctx.ctx);
let indices = IntModel(SizeT).array_alloca(generator, ctx, ndims.value, "indices");
call_nac3_nditer_initialize(generator, ctx, nditer, ndarray.instance, indices);
NDIterHandle { ndarray, instance: nditer, indices }
}
#[must_use]
pub fn has_next<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Bool> {
call_nac3_nditer_has_next(generator, ctx, self.instance)
}
pub fn next<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
call_nac3_nditer_next(generator, ctx, self.instance);
}
#[must_use]
pub fn get_pointer<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> PointerValue<'ctx> {
let elem_ty = ctx.get_llvm_type(generator, self.ndarray.dtype);
let p = self.instance.get(generator, ctx, |f| f.element, "element");
ctx.builder
.build_pointer_cast(p.value, elem_ty.ptr_type(AddressSpace::default()), "element")
.unwrap()
}
#[must_use]
pub fn get_scalar<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
let p = self.get_pointer(generator, ctx);
let value = ctx.builder.build_load(p, "value").unwrap();
AnyObject { ty: self.ndarray.dtype, value }
}
#[must_use]
pub fn get_index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
self.instance.get(generator, ctx, |f| f.nth, "index")
}
#[must_use]
pub fn get_indices(&self) -> Ptr<'ctx, IntModel<SizeT>> {
self.indices
}
}

View File

@ -1,159 +0,0 @@
use std::cmp::max;
use crate::codegen::{
irrt::{
call_nac3_ndarray_float64_matmul_at_least_2d, call_nac3_ndarray_matmul_calculate_shapes,
},
model::*,
object::ndarray::indexing::RustNDIndex,
CodeGenContext, CodeGenerator,
};
use super::{NDArrayObject, NDArrayOut};
impl<'ctx> NDArrayObject<'ctx> {
/// TODO: Document me
fn matmul_helper<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: Self,
b: Self,
) -> Self {
assert!(a.ndims >= 2);
assert!(b.ndims >= 2);
assert!(ctx.unifier.unioned(ctx.primitives.float, a.dtype));
assert!(ctx.unifier.unioned(ctx.primitives.float, b.dtype));
let sizet_model = IntModel(SizeT);
let final_ndims_int = max(a.ndims, b.ndims);
let a_ndims = a.get_ndims(generator, ctx.ctx);
let a_shape = a.instance.get(generator, ctx, |f| f.shape, "a_shape");
let b_ndims = b.get_ndims(generator, ctx.ctx);
let b_shape = b.instance.get(generator, ctx, |f| f.shape, "b_shape");
let final_ndims = sizet_model.constant(generator, ctx.ctx, final_ndims_int);
let new_a_shape =
sizet_model.array_alloca(generator, ctx, final_ndims.value, "new_a_shape");
let new_b_shape =
sizet_model.array_alloca(generator, ctx, final_ndims.value, "new_b_shape");
let dst_shape = sizet_model.array_alloca(generator, ctx, final_ndims.value, "dst_shape");
call_nac3_ndarray_matmul_calculate_shapes(
generator,
ctx,
a_ndims,
a_shape,
b_ndims,
b_shape,
final_ndims,
new_a_shape,
new_b_shape,
dst_shape,
);
let new_a = a.broadcast_to(generator, ctx, final_ndims_int, new_a_shape);
let new_b = b.broadcast_to(generator, ctx, final_ndims_int, new_b_shape);
let dst = NDArrayObject::alloca(
generator,
ctx,
ctx.primitives.float,
final_ndims_int,
"matmul_result",
);
dst.copy_shape_from_array(generator, ctx, dst_shape);
dst.create_data(generator, ctx);
call_nac3_ndarray_float64_matmul_at_least_2d(
generator,
ctx,
new_a.instance,
new_b.instance,
dst.instance,
);
dst
}
/// Perform `np.matmul` according to the rules in
/// <https://numpy.org/doc/stable/reference/generated/numpy.matmul.html>.
///
/// This function always return an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`].
pub fn matmul<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: Self,
b: Self,
out: NDArrayOut<'ctx>,
) -> Self {
// Sanity check, but type inference should prevent this.
assert!(a.ndims > 0 && b.ndims > 0, "np.matmul disallows scalar input");
/*
If both arguments are 2-D they are multiplied like conventional matrices.
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
*/
let new_a = if a.ndims == 1 {
// Prepend 1 to its dimensions
a.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis], "new_a")
} else {
a
};
let new_b = if b.ndims == 1 {
// Append 1 to its dimensions
b.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis], "new_a")
} else {
b
};
// NOTE: `result` will always be a newly allocated ndarray.
// Current implementation cannot do in-place matrix muliplication.
let mut result = NDArrayObject::matmul_helper(generator, ctx, new_a, new_b);
let i32_model = IntModel(Int32); // TODO: Upgrade to SizeT
let zero = i32_model.const_0(generator, ctx.ctx);
if a.ndims == 1 {
// Remove the prepended 1
result = result.index(
generator,
ctx,
&[RustNDIndex::SingleElement(zero), RustNDIndex::Ellipsis],
"result_no_prepend_1",
);
}
if b.ndims == 1 {
// Remove the appended 1
result = result.index(
generator,
ctx,
&[RustNDIndex::Ellipsis, RustNDIndex::SingleElement(zero)],
"result_no_append_1",
);
}
match out {
NDArrayOut::NewNDArray { dtype } => {
// We don't support auto-casting right now, nor anything other than float64.
// Force the output dtype to be float64.
assert!(ctx.unifier.unioned(ctx.primitives.float, dtype));
result
}
NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => {
// TODO: It is possible to check the shapes before computing the matmul to save resources.
let result_shape = result.instance.get(generator, ctx, |f| f.shape, "result_shape");
out_ndarray.check_can_be_written_by_out(generator, ctx, result.ndims, result_shape);
// TODO: We can just set `out_ndarray.data` to `result.data`. Should we?
out_ndarray.copy_data_from(generator, ctx, result);
out_ndarray
}
}
}
}

View File

@ -1,131 +0,0 @@
use inkwell::values::{BasicValue, BasicValueEnum};
use crate::{
codegen::{model::*, object::AnyObject, CodeGenContext, CodeGenerator},
typecheck::typedef::{Type, TypeEnum},
};
use super::NDArrayObject;
impl<'ctx> AnyObject<'ctx> {
/// Promote this scalar to an unsized ndarray (like doing `np.asarray`).
///
/// The scalar value is allocated onto the stack, and the ndarray's `data` will point to that
/// allocated value.
pub fn as_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> NDArrayObject<'ctx> {
let pbyte_model = PtrModel(IntModel(Byte));
// We have to put the value on the stack to get a data pointer.
let data = ctx.builder.build_alloca(self.value.get_type(), "as_ndarray_scalar").unwrap();
ctx.builder.build_store(data, self.value).unwrap();
let data = pbyte_model.pointer_cast(generator, ctx, data, "data");
let ndarray = NDArrayObject::alloca(generator, ctx, self.ty, 0, "scalar_ndarray");
ndarray.instance.set(ctx, |f| f.data, data);
ndarray
}
}
/// A convenience enum for implementing scalar/ndarray agnostic utilities.
#[derive(Debug, Clone, Copy)]
pub enum ScalarOrNDArray<'ctx> {
Scalar(AnyObject<'ctx>),
NDArray(NDArrayObject<'ctx>),
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
#[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar.value,
ScalarOrNDArray::NDArray(ndarray) => ndarray.instance.value.as_basic_value_enum(),
}
}
#[must_use]
pub fn into_scalar(&self) -> AnyObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(_ndarray) => panic!("Got NDArray"),
ScalarOrNDArray::Scalar(scalar) => *scalar,
}
}
#[must_use]
pub fn into_ndarray(&self) -> NDArrayObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
ScalarOrNDArray::Scalar(_scalar) => panic!("Got Scalar"),
}
}
/// Convert this [`ScalarOrNDArray`] to an ndarray - behaves like `np.asarray`.
/// - If this is an ndarray, the ndarray is returned.
/// - If this is a scalar, an unsized ndarray view is created on it.
pub fn as_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> NDArrayObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(ndarray) => *ndarray,
ScalarOrNDArray::Scalar(scalar) => scalar.as_ndarray(generator, ctx),
}
}
#[must_use]
pub fn dtype(&self) -> Type {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar.ty,
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
}
}
}
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for AnyObject<'ctx> {
type Error = ();
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
match value {
ScalarOrNDArray::Scalar(scalar) => Ok(*scalar),
ScalarOrNDArray::NDArray(_ndarray) => Err(()),
}
}
}
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> {
type Error = ();
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
match value {
ScalarOrNDArray::Scalar(_scalar) => Err(()),
ScalarOrNDArray::NDArray(ndarray) => Ok(*ndarray),
}
}
}
/// Split an [`AnyObject`] into a [`ScalarOrNDArray`] depending on its [`Type`].
pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
) -> ScalarOrNDArray<'ctx> {
// TODO: Automatically convert a list into an ndarray?
match &*ctx.unifier.get_ty(object.ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let ndarray = NDArrayObject::from_object(generator, ctx, object);
ScalarOrNDArray::NDArray(ndarray)
}
_ => {
let scalar = AnyObject { ty: object.ty, value: object.value };
ScalarOrNDArray::Scalar(scalar)
}
}
}

View File

@ -1,111 +0,0 @@
use util::gen_for_model_auto;
use crate::{
codegen::{
model::*,
object::{list::ListObject, tuple::TupleObject, AnyObject},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::TypeEnum,
};
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
///
/// * `sequence` - The `sequence` parameter.
/// * `sequence_ty` - The typechecker type of `sequence`
///
/// The `sequence` argument type may only be one of the following:
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
///
/// All `int32` values will be sign-extended to `SizeT`.
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
input_sequence: AnyObject<'ctx>,
) -> (Int<'ctx, SizeT>, Ptr<'ctx, IntModel<SizeT>>) {
let sizet_model = IntModel(SizeT);
let zero = sizet_model.const_0(generator, ctx.ctx);
let one = sizet_model.const_1(generator, ctx.ctx);
// The result `list` to return.
match &*ctx.unifier.get_ty(input_sequence.ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
// Check `input_sequence`
let input_sequence = ListObject::from_object(generator, ctx, input_sequence);
let len = input_sequence.instance.gep(ctx, |f| f.len).load(generator, ctx, "len");
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
// Load all the `int32`s from the input_sequence, cast them to `SizeT`, and store them into `result`
gen_for_model_auto(generator, ctx, zero, len, one, |generator, ctx, _hooks, i| {
// Load the i-th int32 in the input sequence
let int = input_sequence
.instance
.get(generator, ctx, |f| f.items, "")
.offset(generator, ctx, i.value, "")
.load(generator, ctx, "")
.value
.into_int_value();
// Cast to SizeT
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
// Store
result.offset(generator, ctx, i.value, "int").store(ctx, int);
Ok(())
})
.unwrap();
(len, result)
}
TypeEnum::TTuple { .. } => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
let input_sequence = TupleObject::from_object(ctx, input_sequence);
let len_int = input_sequence.len_static();
let len = sizet_model.constant(generator, ctx.ctx, len_int as u64);
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
for i in 0..len_int {
// Get the i-th element off of the tuple and load it into `result`.
let int = input_sequence.get(ctx, i, "dim").value.into_int_value();
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
let offset = sizet_model.constant(generator, ctx.ctx, i as u64);
result.offset(generator, ctx, offset.value, "int").store(ctx, int);
}
(len, result)
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
{
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
let input_int = input_sequence.value.into_int_value();
let len = sizet_model.const_1(generator, ctx.ctx);
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, input_int, "int");
// Storing into result[0]
result.store(ctx, int);
(len, result)
}
_ => panic!(
"encountered unknown sequence type: {}",
ctx.unifier.stringify(input_sequence.ty)
),
}
}

View File

@ -1,41 +0,0 @@
use crate::codegen::{
irrt::calculate_len_for_slice_range, model::*, structure::RangeModel, CodeGenContext,
CodeGenerator,
};
use super::AnyObject;
/// A `range` in NAC3
pub struct RangeObject<'ctx> {
pub instance: Ptr<'ctx, RangeModel>,
}
impl<'ctx> RangeObject<'ctx> {
pub fn from_object<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
object: AnyObject<'ctx>,
) -> Self {
assert!(ctx.unifier.unioned(ctx.primitives.range, object.ty)); // Sanity check on type.
let model = PtrModel(RangeModel::default());
let instance = model.check_value(generator, ctx.ctx, object.value).unwrap();
RangeObject { instance }
}
/// Get the `len()` of this range.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Int32> {
let start = self.instance.gep_start(generator, ctx, "").load(generator, ctx, "start");
let stop = self.instance.gep_stop(generator, ctx, "").load(generator, ctx, "stop");
let step = self.instance.gep_step(generator, ctx, "").load(generator, ctx, "step");
// TODO: Refactor this
let len =
calculate_len_for_slice_range(generator, ctx, start.value, stop.value, step.value);
IntModel(Int32).check_value(generator, ctx.ctx, len).unwrap()
}
}

View File

@ -1,113 +0,0 @@
use core::panic;
use inkwell::values::StructValue;
use itertools::Itertools;
use crate::{
codegen::{model::*, CodeGenContext, CodeGenerator},
typecheck::typedef::{Type, TypeEnum},
};
use super::AnyObject;
/// A NAC3 tuple object.
///
/// NOTE: This struct has no copy trait.
#[derive(Debug, Clone)]
pub struct TupleObject<'ctx> {
/// The type of the tuple.
pub tys: Vec<Type>,
/// The underlying LLVM value of this tuple.
pub value: StructValue<'ctx>,
}
impl<'ctx> TupleObject<'ctx> {
// NOTE: There is no Model abstraction for Tuples with arbitrary lengths.
// Everything has to be done raw with Inkwell.
pub fn from_object(ctx: &mut CodeGenContext<'ctx, '_>, object: AnyObject<'ctx>) -> Self {
// TODO: Keep `is_vararg_ctx` from TTuple?
// Sanity check on object type.
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty(object.ty) else {
panic!(
"Expected type to be a TypeEnum::TTuple, got {}",
ctx.unifier.stringify(object.ty)
);
};
let value = object.value.into_struct_value();
let value_num_fields = value.get_type().count_fields() as usize;
assert!(
value_num_fields == tys.len(),
"Tuple type has {} item(s), but the LLVM struct value has {} field(s)",
tys.len(),
value_num_fields
);
TupleObject { tys: tys.clone(), value }
}
/// Convenience function. Create a [`TupleObject`] from an iterator of objects.
pub fn create<I, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
objects: I,
name: &str,
) -> Self
where
I: IntoIterator<Item = AnyObject<'ctx>>,
{
let (values, tys): (Vec<_>, Vec<_>) =
objects.into_iter().map(|object| (object.value, object.ty)).unzip();
let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec();
let llvm_tuple_ty = ctx.ctx.struct_type(&llvm_tys, false);
let pllvm_tuple = ctx.builder.build_alloca(llvm_tuple_ty, "tuple").unwrap();
for (i, val) in values.into_iter().enumerate() {
let pval = ctx.builder.build_struct_gep(pllvm_tuple, i as u32, "value").unwrap();
ctx.builder.build_store(pval, val).unwrap();
}
let value = ctx.builder.build_load(pllvm_tuple, name).unwrap().into_struct_value();
TupleObject { tys, value }
}
/// Get the `len()` of this tuple statically.
///
/// We statically know the lengths of tuples in NAC3 when compiling.
#[must_use]
pub fn len_static(&self) -> usize {
self.tys.len()
}
/// Get the `len()` of this tuple.
#[must_use]
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, SizeT> {
IntModel(SizeT).constant(generator, ctx.ctx, self.len_static() as u64)
}
/// Check if this tuple is an empty/unit tuple.
#[must_use]
pub fn is_empty(&self) -> bool {
self.len_static() == 0
}
/// Get the `i`-th (0-based) object in this tuple.
pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> {
assert!(
i < self.len_static(),
"Tuple object with length {} have index {i}",
self.len_static()
);
let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap();
let ty = self.tys[i];
AnyObject { ty, value }
}
}

View File

@ -1,27 +1,3 @@
use super::model::*;
use super::object::ndarray::indexing::util::gen_ndarray_subscript_ndindexes;
use super::object::ndarray::scalar::split_scalar_or_ndarray;
use super::object::ndarray::NDArrayObject;
use super::object::AnyObject;
use super::{
super::symbol_resolver::ValueEnum,
expr::destructure_range,
irrt::{handle_slice_indices, list_slice_assignment},
structure::{CSlice, Exception},
CodeGenContext, CodeGenerator, Int32, IntModel, Ptr, StructModel,
};
use crate::{
codegen::{
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
expr::gen_binop_expr,
gen_in_range_check,
},
toplevel::{DefinitionId, TopLevelDef},
typecheck::{
magic_methods::Binop,
typedef::{iter_type_vars, FunSignature, Type, TypeEnum},
},
};
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock, basic_block::BasicBlock,
@ -30,10 +6,28 @@ use inkwell::{
IntPredicate, IntPredicate,
}; };
use itertools::{izip, Itertools}; use itertools::{izip, Itertools};
use nac3parser::ast::{ use nac3parser::ast::{
Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef, Constant, ExcepthandlerKind, Expr, ExprKind, Location, Stmt, StmtKind, StrRef,
}; };
use super::{
expr::{destructure_range, gen_binop_expr},
gen_in_range_check,
irrt::{handle_slice_indices, list_slice_assignment},
macros::codegen_unreachable,
values::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
CodeGenContext, CodeGenerator,
};
use crate::{
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::{
magic_methods::Binop,
typedef::{iter_type_vars, FunSignature, Type, TypeEnum},
},
};
/// See [`CodeGenerator::gen_var_alloc`]. /// See [`CodeGenerator::gen_var_alloc`].
pub fn gen_var<'ctx>( pub fn gen_var<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -127,7 +121,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
return Ok(None); return Ok(None);
}; };
let BasicValueEnum::PointerValue(ptr) = val else { let BasicValueEnum::PointerValue(ptr) = val else {
unreachable!(); codegen_unreachable!(ctx);
}; };
unsafe { unsafe {
ctx.builder.build_in_bounds_gep( ctx.builder.build_in_bounds_gep(
@ -141,7 +135,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
} }
.unwrap() .unwrap()
} }
_ => unreachable!(), _ => codegen_unreachable!(ctx),
})) }))
} }
@ -182,6 +176,14 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
} }
} }
let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?; let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?;
// Perform i1 <-> i8 conversion as needed
let val = if ctx.unifier.unioned(target.custom.unwrap(), ctx.primitives.bool) {
generator.bool_to_i8(ctx, val.into_int_value()).into()
} else {
val
};
ctx.builder.build_store(ptr, val).unwrap(); ctx.builder.build_store(ptr, val).unwrap();
} }
}; };
@ -199,12 +201,12 @@ pub fn gen_assign_target_list<'ctx, G: CodeGenerator>(
// Deconstruct the tuple `value` // Deconstruct the tuple `value`
let BasicValueEnum::StructValue(tuple) = value.to_basic_value_enum(ctx, generator, value_ty)? let BasicValueEnum::StructValue(tuple) = value.to_basic_value_enum(ctx, generator, value_ty)?
else { else {
unreachable!() codegen_unreachable!(ctx)
}; };
// NOTE: Currently, RHS's type is forced to be a Tuple by the type inferencer. // NOTE: Currently, RHS's type is forced to be a Tuple by the type inferencer.
let TypeEnum::TTuple { ty: tuple_tys, .. } = &*ctx.unifier.get_ty(value_ty) else { let TypeEnum::TTuple { ty: tuple_tys, .. } = &*ctx.unifier.get_ty(value_ty) else {
unreachable!(); codegen_unreachable!(ctx);
}; };
assert_eq!(tuple.get_type().count_fields() as usize, tuple_tys.len()); assert_eq!(tuple.get_type().count_fields() as usize, tuple_tys.len());
@ -264,7 +266,7 @@ pub fn gen_assign_target_list<'ctx, G: CodeGenerator>(
// Now assign with that sub-tuple to the starred target. // Now assign with that sub-tuple to the starred target.
generator.gen_assign(ctx, target, ValueEnum::Dynamic(sub_tuple_val), sub_tuple_ty)?; generator.gen_assign(ctx, target, ValueEnum::Dynamic(sub_tuple_val), sub_tuple_ty)?;
} else { } else {
unreachable!() // The typechecker ensures this codegen_unreachable!(ctx) // The typechecker ensures this
} }
// Handle assignment after the starred target // Handle assignment after the starred target
@ -308,11 +310,13 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, target_ty)? .to_basic_value_enum(ctx, generator, target_ty)?
.into_pointer_value(); .into_pointer_value();
let target = ListValue::from_ptr_val(target, llvm_usize, None); let target = ListValue::from_pointer_value(target, llvm_usize, None);
if let ExprKind::Slice { .. } = &key.node { if let ExprKind::Slice { .. } = &key.node {
// Handle assigning to a slice // Handle assigning to a slice
let ExprKind::Slice { lower, upper, step } = &key.node else { unreachable!() }; let ExprKind::Slice { lower, upper, step } = &key.node else {
codegen_unreachable!(ctx)
};
let Some((start, end, step)) = handle_slice_indices( let Some((start, end, step)) = handle_slice_indices(
lower, lower,
upper, upper,
@ -327,7 +331,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
let value = let value =
value.to_basic_value_enum(ctx, generator, value_ty)?.into_pointer_value(); value.to_basic_value_enum(ctx, generator, value_ty)?.into_pointer_value();
let value = ListValue::from_ptr_val(value, llvm_usize, None); let value = ListValue::from_pointer_value(value, llvm_usize, None);
let target_item_ty = ctx.get_llvm_type(generator, target_item_ty); let target_item_ty = ctx.get_llvm_type(generator, target_item_ty);
let Some(src_ind) = handle_slice_indices( let Some(src_ind) = handle_slice_indices(
@ -407,45 +411,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{ {
// Handle NDArray item assignment // Handle NDArray item assignment
// Process target todo!("ndarray subscript assignment is not yet implemented");
let target = generator
.gen_expr(ctx, target)?
.unwrap()
.to_basic_value_enum(ctx, generator, target_ty)?;
let target = AnyObject { value: target, ty: target_ty };
// Process key
let key = gen_ndarray_subscript_ndindexes(generator, ctx, key)?;
// Process value
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
let value = AnyObject { value, ty: value_ty };
/*
Reference code:
```python
target = target[key]
value = np.asarray(value)
shape = np.broadcast_shape((target, value))
target = np.broadcast_to(target, shape)
value = np.broadcast_to(value, shape)
...and finally copy 1-1 from value to target.
```
*/
let target = NDArrayObject::from_object(generator, ctx, target);
let target = target.index(generator, ctx, &key, "assign_target_ndarray");
let value = split_scalar_or_ndarray(generator, ctx, value).as_ndarray(generator, ctx);
let broadcast_result = NDArrayObject::broadcast(generator, ctx, &[target, value]);
let target = broadcast_result.ndarrays[0];
let value = broadcast_result.ndarrays[1];
target.copy_data_from(generator, ctx, value);
} }
_ => { _ => {
panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty)); panic!("encountered unknown target type: {}", ctx.unifier.stringify(target_ty));
@ -460,7 +426,9 @@ pub fn gen_for<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else { unreachable!() }; let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else {
codegen_unreachable!(ctx)
};
// var_assignment static values may be changed in another branch // var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch // if so, remove the static value as it may not be correct in this branch
@ -495,14 +463,15 @@ pub fn gen_for<G: CodeGenerator>(
TypeEnum::TObj { obj_id, .. } TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() => if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
{ {
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); let iter_val =
RangeValue::from_pointer_value(iter_val.into_pointer_value(), Some("range"));
// Internal variable for loop; Cannot be assigned // Internal variable for loop; Cannot be assigned
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
let Some(target_i) = let Some(target_i) =
generator.gen_store_target(ctx, target, Some("for.target.addr"))? generator.gen_store_target(ctx, target, Some("for.target.addr"))?
else { else {
unreachable!() codegen_unreachable!(ctx)
}; };
let (start, stop, step) = destructure_range(ctx, iter_val); let (start, stop, step) = destructure_range(ctx, iter_val);
@ -663,9 +632,9 @@ pub struct BreakContinueHooks<'ctx> {
/// ``` /// ```
/// ///
/// * `init` - A lambda containing IR statements declaring and initializing loop variables. The /// * `init` - A lambda containing IR statements declaring and initializing loop variables. The
/// return value is a [Clone] value which will be passed to the other lambdas. /// return value is a [Clone] value which will be passed to the other lambdas.
/// * `cond` - A lambda containing IR statements checking whether the loop should continue /// * `cond` - A lambda containing IR statements checking whether the loop should continue
/// executing. The result value must be an `i1` indicating if the loop should continue. /// executing. The result value must be an `i1` indicating if the loop should continue.
/// * `body` - A lambda containing IR statements within the loop body. /// * `body` - A lambda containing IR statements within the loop body.
/// * `update` - A lambda containing IR statements updating loop variables. /// * `update` - A lambda containing IR statements updating loop variables.
pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
@ -748,9 +717,9 @@ where
/// ``` /// ```
/// ///
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used /// * `init_val` - The initial value of the loop variable. The type of this value will also be used
/// as the type of the loop variable. /// as the type of the loop variable.
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum /// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum
/// value should be treated as inclusive (as opposed to exclusive). /// value should be treated as inclusive (as opposed to exclusive).
/// * `body` - A lambda containing IR statements within the loop body. /// * `body` - A lambda containing IR statements within the loop body.
/// * `incr_val` - The value to increment the loop variable on each iteration. /// * `incr_val` - The value to increment the loop variable on each iteration.
pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>( pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
@ -821,12 +790,12 @@ where
/// ///
/// - `is_unsigned`: Whether to treat the values of the `range` as unsigned. /// - `is_unsigned`: Whether to treat the values of the `range` as unsigned.
/// - `start_fn`: A lambda of IR statements that retrieves the `start` value of the `range`-like /// - `start_fn`: A lambda of IR statements that retrieves the `start` value of the `range`-like
/// iterable. /// iterable.
/// - `stop_fn`: A lambda of IR statements that retrieves the `stop` value of the `range`-like /// - `stop_fn`: A lambda of IR statements that retrieves the `stop` value of the `range`-like
/// iterable. This value will be extended to the size of `start`. /// iterable. This value will be extended to the size of `start`.
/// - `stop_inclusive`: Whether the stop value should be treated as inclusive. /// - `stop_inclusive`: Whether the stop value should be treated as inclusive.
/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like /// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like
/// iterable. This value will be extended to the size of `start`. /// iterable. This value will be extended to the size of `start`.
/// - `body_fn`: A lambda of IR statements within the loop body. /// - `body_fn`: A lambda of IR statements within the loop body.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
@ -847,7 +816,7 @@ where
BodyFn: FnOnce( BodyFn: FnOnce(
&mut G, &mut G,
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks, BreakContinueHooks<'ctx>,
IntValue<'ctx>, IntValue<'ctx>,
) -> Result<(), String>, ) -> Result<(), String>,
{ {
@ -945,7 +914,7 @@ pub fn gen_while<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::While { test, body, orelse, .. } = &stmt.node else { unreachable!() }; let StmtKind::While { test, body, orelse, .. } = &stmt.node else { codegen_unreachable!(ctx) };
// var_assignment static values may be changed in another branch // var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch // if so, remove the static value as it may not be correct in this branch
@ -975,7 +944,7 @@ pub fn gen_while<G: CodeGenerator>(
return Ok(()); return Ok(());
}; };
let BasicValueEnum::IntValue(test) = test else { unreachable!() }; let BasicValueEnum::IntValue(test) = test else { codegen_unreachable!(ctx) };
ctx.builder ctx.builder
.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb) .build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb)
@ -1123,7 +1092,7 @@ pub fn gen_if<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::If { test, body, orelse, .. } = &stmt.node else { unreachable!() }; let StmtKind::If { test, body, orelse, .. } = &stmt.node else { codegen_unreachable!(ctx) };
// var_assignment static values may be changed in another branch // var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch // if so, remove the static value as it may not be correct in this branch
@ -1246,11 +1215,11 @@ pub fn exn_constructor<'ctx>(
let zelf_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(zelf_ty) { let zelf_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(zelf_ty) {
obj_id.0 obj_id.0
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let defs = ctx.top_level.definitions.read(); let defs = ctx.top_level.definitions.read();
let def = defs[zelf_id].read(); let def = defs[zelf_id].read();
let TopLevelDef::Class { name: zelf_name, .. } = &*def else { unreachable!() }; let TopLevelDef::Class { name: zelf_name, .. } = &*def else { codegen_unreachable!(ctx) };
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name); let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name);
unsafe { unsafe {
let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap(); let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap();
@ -1307,36 +1276,47 @@ pub fn exn_constructor<'ctx>(
pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>( pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
exception: Option<Ptr<'ctx, StructModel<Exception>>>, exception: Option<&BasicValueEnum<'ctx>>,
loc: Location, loc: Location,
) { ) {
if let Some(pexn) = exception { if let Some(exception) = exception {
let i32_model = IntModel(Int32); unsafe {
let cslice_model = StructModel(CSlice); let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let exception = exception.into_pointer_value();
let file_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(1, false)], "file_ptr")
.unwrap();
let filename = ctx.gen_string(generator, loc.file.0);
ctx.builder.build_store(file_ptr, filename).unwrap();
let row_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(2, false)], "row_ptr")
.unwrap();
ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false)).unwrap();
let col_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(3, false)], "col_ptr")
.unwrap();
ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false)).unwrap();
// Get and store filename let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let filename = loc.file.0; let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap());
let filename = ctx.gen_string(generator, &String::from(filename)).value; let name_ptr = ctx
let filename = cslice_model.check_value(generator, ctx.ctx, filename).unwrap(); .builder
pexn.set(ctx, |f| f.filename, filename); .build_in_bounds_gep(exception, &[zero, int32.const_int(4, false)], "name_ptr")
.unwrap();
let row = i32_model.constant(generator, ctx.ctx, loc.row as u64); ctx.builder.build_store(name_ptr, fun_name).unwrap();
pexn.set(ctx, |f| f.line, row); }
let column = i32_model.constant(generator, ctx.ctx, loc.column as u64);
pexn.set(ctx, |f| f.column, column);
let current_fn = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let fn_name = ctx.gen_string(generator, current_fn.get_name().to_str().unwrap());
pexn.set(ctx, |f| f.function, fn_name);
let raise = get_builtins(generator, ctx, "__nac3_raise"); let raise = get_builtins(generator, ctx, "__nac3_raise");
ctx.build_call_or_invoke(raise, &[pexn.value.into()], "raise"); let exception = *exception;
ctx.build_call_or_invoke(raise, &[exception], "raise");
} else { } else {
let resume = get_builtins(generator, ctx, "__nac3_resume"); let resume = get_builtins(generator, ctx, "__nac3_resume");
ctx.build_call_or_invoke(resume, &[], "resume"); ctx.build_call_or_invoke(resume, &[], "resume");
} }
ctx.builder.build_unreachable().unwrap(); ctx.builder.build_unreachable().unwrap();
} }
@ -1347,7 +1327,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
target: &Stmt<Option<Type>>, target: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::Try { body, handlers, orelse, finalbody, .. } = &target.node else { let StmtKind::Try { body, handlers, orelse, finalbody, .. } = &target.node else {
unreachable!() codegen_unreachable!(ctx)
}; };
// if we need to generate anything related to exception, we must have personality defined // if we need to generate anything related to exception, we must have personality defined
@ -1424,7 +1404,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) { if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) {
*obj_id *obj_id
} else { } else {
unreachable!() codegen_unreachable!(ctx)
}; };
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name); let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name);
let exn_id = ctx.resolver.get_string_id(&exception_name); let exn_id = ctx.resolver.get_string_id(&exception_name);
@ -1696,6 +1676,23 @@ pub fn gen_return<G: CodeGenerator>(
} else { } else {
None None
}; };
// Remap boolean return type into i1
let value = value.map(|ret_val| {
// The "return type" of a sret function is in the first parameter
let expected_ty = if ctx.need_sret {
func.get_type().get_param_types()[0]
} else {
func.get_type().get_return_type().unwrap()
};
if matches!(expected_ty, BasicTypeEnum::IntType(ty) if ty.get_bit_width() == 1) {
generator.bool_to_i1(ctx, ret_val.into_int_value()).into()
} else {
ret_val
}
});
if let Some(return_target) = ctx.return_target { if let Some(return_target) = ctx.return_target {
if let Some(value) = value { if let Some(value) = value {
ctx.builder.build_store(ctx.return_buffer.unwrap(), value).unwrap(); ctx.builder.build_store(ctx.return_buffer.unwrap(), value).unwrap();
@ -1706,25 +1703,6 @@ pub fn gen_return<G: CodeGenerator>(
ctx.builder.build_store(ctx.return_buffer.unwrap(), value.unwrap()).unwrap(); ctx.builder.build_store(ctx.return_buffer.unwrap(), value.unwrap()).unwrap();
ctx.builder.build_return(None).unwrap(); ctx.builder.build_return(None).unwrap();
} else { } else {
// Remap boolean return type into i1
let value = value.map(|v| {
let expected_ty = func.get_type().get_return_type().unwrap();
let ret_val = v.as_basic_value_enum();
if expected_ty.is_int_type() && ret_val.is_int_value() {
let ret_type = expected_ty.into_int_type();
let ret_val = ret_val.into_int_value();
if ret_type.get_bit_width() == 1 && ret_val.get_type().get_bit_width() != 1 {
generator.bool_to_i1(ctx, ret_val)
} else {
ret_val
}
.into()
} else {
ret_val
}
});
let value = value.as_ref().map(|v| v as &dyn BasicValue); let value = value.as_ref().map(|v| v as &dyn BasicValue);
ctx.builder.build_return(value).unwrap(); ctx.builder.build_return(value).unwrap();
} }
@ -1793,52 +1771,95 @@ pub fn gen_stmt<G: CodeGenerator>(
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
StmtKind::Raise { exc, .. } => { StmtKind::Raise { exc, .. } => {
if let Some(exc) = exc { if let Some(exc) = exc {
let exc = if let Some(v) = generator.gen_expr(ctx, exc)? { let exn = if let ExprKind::Name { id, .. } = &exc.node {
// Handle "raise Exception" short form
let def_id = ctx.resolver.get_identifier_def(*id).map_err(|e| {
format!("{} (at {})", e.iter().next().unwrap(), exc.location)
})?;
let def = ctx.top_level.definitions.read();
let TopLevelDef::Class { constructor, .. } = *def[def_id.0].read() else {
return Err(format!("Failed to resolve symbol {id} (at {})", exc.location));
};
let TypeEnum::TFunc(signature) =
ctx.unifier.get_ty(constructor.unwrap()).as_ref().clone()
else {
return Err(format!("Failed to resolve symbol {id} (at {})", exc.location));
};
generator
.gen_call(ctx, None, (&signature, def_id), Vec::default())?
.map(Into::into)
} else {
generator.gen_expr(ctx, exc)?
};
let exc = if let Some(v) = exn {
v.to_basic_value_enum(ctx, generator, exc.custom.unwrap())? v.to_basic_value_enum(ctx, generator, exc.custom.unwrap())?
} else { } else {
return Ok(()); return Ok(());
}; };
gen_raise(generator, ctx, Some(&exc), stmt.location);
let pexn_model = PtrModel(StructModel(Exception));
let exn = pexn_model.check_value(generator, ctx.ctx, exc).unwrap();
gen_raise(generator, ctx, Some(exn), stmt.location);
} else { } else {
gen_raise(generator, ctx, None, stmt.location); gen_raise(generator, ctx, None, stmt.location);
} }
} }
StmtKind::Assert { test, msg, .. } => { StmtKind::Assert { test, msg, .. } => {
let byte_model = IntModel(Byte); let test = if let Some(v) = generator.gen_expr(ctx, test)? {
let cslice_model = StructModel(CSlice); v.to_basic_value_enum(ctx, generator, test.custom.unwrap())?
} else {
let Some(test) = generator.gen_expr(ctx, test)? else {
return Ok(()); return Ok(());
}; };
let test = test.to_basic_value_enum(ctx, generator, ctx.primitives.bool)?;
let test = byte_model.check_value(generator, ctx.ctx, test).unwrap(); // Python `bool` is represented as `i8` in nac3core
// Check `msg`
let err_msg = match msg { let err_msg = match msg {
Some(msg) => { Some(msg) => {
let Some(msg) = generator.gen_expr(ctx, msg)? else { if let Some(v) = generator.gen_expr(ctx, msg)? {
v.to_basic_value_enum(ctx, generator, msg.custom.unwrap())?
} else {
return Ok(()); return Ok(());
}; }
let msg = msg.to_basic_value_enum(ctx, generator, ctx.primitives.str)?;
cslice_model.check_value(generator, ctx.ctx, msg).unwrap()
} }
None => ctx.gen_string(generator, ""), None => ctx.gen_string(generator, "").into(),
}; };
ctx.make_assert_impl( ctx.make_assert_impl(
generator, generator,
test.value, generator.bool_to_i1(ctx, test.into_int_value()),
"0:AssertionError", "0:AssertionError",
err_msg, err_msg,
[None, None, None], [None, None, None],
stmt.location, stmt.location,
); );
} }
StmtKind::Global { names, .. } => {
let registered_globals = ctx
.top_level
.definitions
.read()
.iter()
.filter_map(|def| {
if let TopLevelDef::Variable { simple_name, ty, .. } = &*def.read() {
Some((*simple_name, *ty))
} else {
None
}
})
.collect_vec();
for id in names {
let Some((_, ty)) = registered_globals.iter().find(|(name, _)| name == id) else {
return Err(format!("{id} is not a global at {}", stmt.location));
};
let resolver = ctx.resolver.clone();
let ptr = resolver
.get_symbol_value(*id, ctx, generator)
.map(|val| val.to_basic_value_enum(ctx, generator, *ty))
.transpose()?
.map(BasicValueEnum::into_pointer_value)
.unwrap();
ctx.var_assignment.insert(*id, (ptr, None, 0));
}
}
_ => unimplemented!(), _ => unimplemented!(),
}; };
Ok(()) Ok(())

Some files were not shown because too many files have changed in this diff Show More