1
0
forked from M-Labs/nac3

Compare commits

..

473 Commits

Author SHA1 Message Date
1531b6cc98 cargo: update dependencies 2024-12-13 19:42:01 +08:00
9bbc40bbfa flake: update dependencies 2024-12-13 19:41:52 +08:00
790e56d106 msys2: update 2024-12-13 19:39:39 +08:00
a00eb7969e [core] codegen: Implement matrix_power
Last of the functions that need to be ported over to strided-ndarray.
2024-12-13 15:23:31 +08:00
27a6f47330 [core] codegen: Implement construction of unsized ndarrays
Partially based on f731e604: core/ndstrides: add more ScalarOrNDArray
and NDArrayObject utils.
2024-12-13 15:23:31 +08:00
061747c67b [core] codegen: Implement NDArrayValue::atleast_nd
Based on 9cfa2622: core/ndstrides: add NDArrayObject::atleast_nd.
2024-12-13 15:23:31 +08:00
dc91d9e35a [core] codegen: Implement ScalarOrNDArray and use it in indexing
Based on 8f9d2d82: core/ndstrides: implement ndarray indexing.
2024-12-13 15:23:31 +08:00
438943ac6f [core] codegen: Implement indexing for NDArray
Based on 8f9d2d82: core/ndstrides: implement ndarray indexing

The functionality for `...` and `np.newaxis` is there in IRRT, but there
is no implementation of them for @kernel Python expressions because of
M-Labs/nac3#486.
2024-12-13 15:23:31 +08:00
678e56c95d [core] irrt: rename NDIndex to NDIndexInt
Unfortunately the name `NDIndex` is used in later commits. Renaming this
typedef to `NDIndexInt` to avoid amending. `NDIndexInt` will be removed
anyway when ndarray strides is completed.
2024-12-13 15:23:31 +08:00
fdfc80ca5f [core] codegen: Implement Slice{Type,Value}, RustSlice
Based on 01c96396: core/irrt: add Slice and Range and part of
8f9d2d82: core/ndstrides: implement ndarray indexing.

Needed for implementing general ndarray indexing.

Currently IRRT slice and range have nothing to do with NAC3's slice
and range. The IRRT slice and range are currently there to implement
ndarray specific features. However, in the future their definitions may
be used to replace that of NAC3's. (NAC3's range is a [i32 x 3], IRRT's
range is a proper struct. NAC3 does not have a slice struct).
2024-12-13 15:23:31 +08:00
8b3429d62a [artiq] Reimplement get_obj_value for strided ndarray
Based on 7ef93472: artiq: reimplement get_obj_value to use ndarray with
strides
2024-12-13 15:23:31 +08:00
f4c5038b95 [artiq] codegen: Reimplement polymorphic_print for strided ndarray
Based on 2a6ee503: artiq: reimplement polymorphic_print for ndarray
2024-12-13 15:23:31 +08:00
ddd16738a6 [core] codegen: implement ndarray iterator NDIter
Based on 50f960ab: core/ndstrides: implement ndarray iterator NDIter

A necessary utility to iterate through all elements in a possibly
strided ndarray.
2024-12-13 15:23:31 +08:00
44c49dc102 [artiq] codegen: Reimplement polymorphic_print for strided ndarray
Based on 2a6ee503: artiq: reimplement polymorphic_print for ndarray
2024-12-13 15:23:31 +08:00
e4bd376587 [core] codegen: Implement ContiguousNDArray
Fixes compatibility with linalg algorithms. matrix_power is missing due
to the need for indexing support.
2024-12-13 15:23:29 +08:00
44498f22f6 [core] codegen: Implement NDArray functions from a0a1f35b 2024-12-13 15:22:11 +08:00
110416d07a [core] codegen/irrt: Add IRRT functions for strided-ndarray 2024-12-13 15:22:11 +08:00
08a7d01a13 [core] Add itemsize and strides to NDArray struct
Temporarily disable linalg ndarray tests as they are not ported to work
with strided-ndarray.
2024-12-13 15:22:09 +08:00
3cd36fddc3 [core] codegen/types: Add check_struct_type_matches_fields
Shorthand for checking if a type is representable by a StructFields
instance.
2024-12-12 11:40:44 +08:00
56a7a9e03d [core] codegen: Add helper functions for create+call functions
Replacement for various FnCall methods from legacy ndstrides
implementation.
2024-12-12 11:30:36 +08:00
574ae40f97 [core] codegen: Add call_memcpy_generic_array
Replacement for Instance<Ptr>::copy_from from legacy ndstrides
implementation.
2024-12-12 11:30:36 +08:00
aa293b6bea [core] codegen: Add type_aligned_alloca 2024-12-12 11:30:35 +08:00
eb4b881690 [core] Expose {types,values}::ndarray modules
Allows better encapsulation of members in these modules rather than
allowing them to leak into types/values mod.
2024-12-12 11:30:14 +08:00
3d0a1d281c [core] Expose irrt::ndarray 2024-12-10 12:49:49 +08:00
ad67a99c8f [core] codegen: Cleanup builtin_fns.rs
- Unpack tuples directly in function argument
- Replace Vec parameters with slices
- Replace unwrap-transform with map-unwrap
2024-12-10 12:49:49 +08:00
8e2b50df21 [core] codegen/ndarray: Cleanup
- Remove redundant size param
- Add *_fields functions and docs
2024-12-09 13:01:08 +08:00
06092ad29b [core] Move alloca and map_value of ProxyType to implementations
These functions may not be invokable by the same set of parameters as
some classes has associated states.
2024-12-09 12:51:50 +08:00
d62c6b95fd [core] codegen/types: Rename StructField::set_from_value 2024-12-09 12:51:50 +08:00
95e29d9997 [core] codegen: Move ndarray type/value as a separate module 2024-12-09 12:51:46 +08:00
536ed2146c [meta] Remove all mentions of build_int_cast
build_int_cast performs signed extension or truncation depending on the
source and target int lengths. This is usually not what we want - We
want zero-extension instead.

Replace all instances of build_int_cast with
build_int_z_extend_or_bit_cast to fix this issue.
2024-12-09 12:51:39 +08:00
d484d44d95 [standalone] linalg: Fix function name in error message 2024-12-09 12:09:57 +08:00
ac978864f2 [meta] Apply clippy suggestions 2024-12-09 12:08:41 +08:00
95254f8464 [meta] Update Cargo dependencies 2024-12-09 12:08:41 +08:00
964945d244 string_store: update embedding map after compilation 2024-12-03 16:45:46 +08:00
ae09a0d444 exceptions: preallocate in NAC3 instead 2024-12-03 16:45:05 +08:00
01edd5af67 [meta] Apply rustfmt changes 2024-11-29 15:43:34 +08:00
015714eee1 copy constructor -> clone 2024-11-28 18:52:53 +08:00
71dec251e3 ld/dwarf: remove reader resets
DWARF reader never had to reverse. Readers are already copied to achieve this effect.
Plus the position that it reverses to might be questionable.
2024-11-28 18:52:53 +08:00
fce61f7b8c ld: fix dwarf sections offset calculations 2024-11-28 18:52:53 +08:00
babc081dbd core/toplevel: update tests 2024-11-27 14:31:57 +08:00
5337dbe23b core/toplevel: add python-like error messages for class definition 2024-11-27 14:31:57 +08:00
f862c01412 core/toplevel: refactor composer 2024-11-27 14:31:53 +08:00
0c9705f5f1 [meta] Apply clippy changes 2024-11-25 16:05:12 +08:00
5f940f86d9 [artiq] Fix obtaining ndarray struct from NDArrayType 2024-11-25 15:01:39 +08:00
5651e00688 flake: add platformdirs artiq dependency 2024-11-22 20:30:30 +08:00
f6745b987f bump sipyco and artiq used for profiling 2024-11-22 19:43:03 +08:00
e0dedc6580 nac3artiq: support kernels sent by content 2024-11-22 19:38:52 +08:00
28f574282c [core_derive] Ignore doctest in example
Causes linker errors for unknown reasons.
2024-11-22 00:00:05 +08:00
144f0922db [core] coregen/types: Implement StructFields for NDArray
Also rename some fields to better align with their naming in numpy.
2024-11-21 14:27:00 +08:00
c58ce9c3a9 [core] codegen/types: Implement NDArray in terms of i8*
Better aligns with the future implementation of ndstrides.
2024-11-21 14:27:00 +08:00
f7e296da53 [core] irrt: Break IRRT into several impl files
Each IRRT file is now mapped to one Rust file.
2024-11-21 14:27:00 +08:00
b58c99369e [core] irrt: Update some IRRT implementation
- Change CSlice to use `void*` for better pointer compatibility
- Only include impl *.hpp files in irrt.cpp
- Refactor typedef to using declaration
- Add missing ``// namespace`
2024-11-21 14:26:58 +08:00
1a535db558 [core] codegen: Add dtype to NDArrayType
We won't have this once NDArray is refactored to strided impl.
2024-11-20 15:35:57 +08:00
1ba2e287a6 [core] codegen: Add Self::llvm_type to all type abstractions 2024-11-20 15:35:57 +08:00
f95f979ad3 core/irrt: fix exception.hpp C++ castings 2024-11-20 15:35:57 +08:00
48e2148c0f core/toplevel/helper: add {extract,create}_ndims 2024-11-20 15:35:57 +08:00
88e57f7120 [core_derive] Initial implementation 2024-11-20 15:35:55 +08:00
d7633c42bc [core] codegen/types: Implement StructField{,s}
Loosely based on FieldTraversal by lyken.
2024-11-19 13:46:25 +08:00
a4f53b6e6b [core] codegen: Refactor ProxyType and ProxyValue
Accepts generator+context object for generic type checking. Also
implements more default trait impl for easier delegation.
2024-11-19 13:46:25 +08:00
9d9ead211e [core] Move Proxies to their own modules 2024-11-19 13:46:23 +08:00
26a1b85206 [core] codegen/classes: Remove Underlying type
This is confusing and we want a better abstraction than this.
2024-11-19 13:45:55 +08:00
2822074b2d [meta] Cleanup from upgrading Rust version
- Remove rust_2024_edition warnings, since it wouldn't be released for
another 3 months
- Fix new clippy warnings
2024-11-19 13:43:57 +08:00
fe67ed076c [meta] Update pre-commit configuration 2024-11-19 13:20:27 +08:00
94e2414df0 [meta] Update cargo dependencies 2024-11-19 13:20:26 +08:00
2cee760404 turn rust_2024_compatibility lints into warnings 2024-11-16 13:41:49 +08:00
230982dc84 update dependencies 2024-11-16 12:40:11 +08:00
2bd3f63991 boolop: terminate both branches with *_end_bb 2024-11-16 12:06:20 +08:00
b53266e9e6 artiq: use async RPC for attributes writeback 2024-11-12 12:04:01 +08:00
86eb22bbf3 artiq: main is always the last module 2024-11-12 12:03:38 +08:00
beaa38047d artiq: suppress main module debug warning 2024-11-12 12:03:08 +08:00
705dc4ff1c artiq: lump return value into attributes writeback RPC 2024-11-12 12:02:35 +08:00
979209a526 binop: expand not operator as loglcal not 2024-11-08 17:12:01 +08:00
c3927d0ef6 [ast] Refactor lazy_static to LazyLock
It is available in Rust 1.80 and reduces a dependency.
2024-10-30 12:29:51 +08:00
202a902cd0 [meta] Update dependencies 2024-10-30 12:29:51 +08:00
b6e2644391 [meta] Update cargo dependencies 2024-10-18 14:17:16 +08:00
45cd01556b [meta] Apply cargo fmt 2024-10-18 14:16:42 +08:00
b6cd2a6993 [meta] Reorganize order of use declarations - Phase 3 2024-10-17 16:25:52 +08:00
a98f33e6d1 [meta] Reorganize order of use declarations - Phase 2
Some more rules:

- For module-level imports, prefer no prefix > super > crate.
- Use crate instead of super if super refers to the crate-level module
2024-10-17 15:57:33 +08:00
5839badadd [standalone] Update globals.py with type-inferred global var 2024-10-07 20:44:08 +08:00
56c845aac4 [standalone] Add support for registering globals without type decl 2024-10-07 20:44:06 +08:00
65a12d9ab3 [core] Refactor registration of top-level variables 2024-10-07 17:05:48 +08:00
9c6685fa8f [core] typecheck/function_check: Fix lookup of defined ids in scope 2024-10-07 16:51:37 +08:00
2bb788e4bb [core] codegen/expr: Materialize implicit globals
Required for when globals are read without the use of a global
declaration.
2024-10-07 13:13:20 +08:00
42a2f243b5 [core] typecheck: Disallow redeclaration of var shadowing global 2024-10-07 13:11:00 +08:00
3ce2eddcdc [core] typecheck/type_inferencer: Infer whether variables are global 2024-10-07 13:10:46 +08:00
51bf126a32 [core] typecheck/type_inferencer: Differentiate global symbols
Required for analyzing use of global symbols before global declaration.
2024-10-07 12:25:00 +08:00
1a197c67f6 [core] toplevel/composer: Reduce lock scope while analyzing function 2024-10-05 15:53:20 +08:00
581b2f7bb2 [standalone] Add demo for global variables 2024-10-04 13:24:30 +08:00
746329ec5d [standalone] Implement symbol resolution for globals 2024-10-04 13:24:30 +08:00
e60e8e837f [core] Add support for global statements 2024-10-04 13:24:27 +08:00
9fdbe9695d [core] Add generator to SymbolResolver::get_symbol_value
Needed in a future commit.
2024-10-04 13:20:29 +08:00
8065e73598 [core] toplevel/composer: Add type analysis for global variables 2024-10-04 13:20:29 +08:00
192290889b [core] Add IdentifierInfo
Keeps track of whether an identifier refers to a global or local
variable.
2024-10-04 13:20:24 +08:00
1407553a2f [core] Implement parsing of global variables
Globals are now parsed into symbol resolver and top level definitions.
2024-10-04 13:18:29 +08:00
c7697606e1 [core] Add TopLevelDef::Variable 2024-10-04 13:09:25 +08:00
88d0ccbf69 [standalone] Explicit panic when encountering a compilation error
Otherwise scripts will continue to execute.
2024-10-04 13:00:16 +08:00
a43b59539c [meta] Move variables declarations closer to where they are first used 2024-10-04 13:00:16 +08:00
fe06b2806f [meta] Reorganize order of use declarations
Use declarations are now grouped into 4 groups:

- Declarations from the standard library
- Declarations from external crates
- Declarations from other crates in this project
- Declarations from within this module

Furthermore, all use declarations are grouped together to enhance
readability. super::super is also replaced by an equivalent crate::
declaration.
2024-10-04 12:52:01 +08:00
7f6c9a25ac [meta] Update Cargo dependencies 2024-10-04 12:52:01 +08:00
6c8382219f msys2: get python via numpy dependencies 2024-09-30 14:27:30 +08:00
9274a7b96b flake: update nixpkgs 2024-09-30 14:22:40 +08:00
d1c0fe2900 cargo: update dependencies 2024-09-30 14:14:43 +08:00
f2c047ba57 artiq: support async rpcs
Co-authored-by: mwojcik <mw@m-labs.hk>
Co-committed-by: mwojcik <mw@m-labs.hk>
2024-09-13 12:12:13 +08:00
5e2e77a500 [meta] Bump inkwell to v0.5 2024-09-13 11:11:14 +08:00
f3cc4702b9 [meta] Update dependencies 2024-09-13 11:11:14 +08:00
3e92c491f5 [standalone] Add tests creating ndarrays with tuple dims 2024-09-11 15:52:43 +08:00
7f629f1579 core: fix comment in unify_call 2024-09-11 15:46:19 +08:00
5640a793e2 core: allow np_full to take tuple shapes 2024-09-11 15:46:19 +08:00
abbaa506ad [standalone] Remove redundant recreation of TargetMachine 2024-09-09 14:27:10 +08:00
f3dc02d646 [meta] Apply cargo fmt 2024-09-09 14:24:52 +08:00
ea217eaea1 [meta] Update pre-commit config
Directly invoke cargo using nix develop to avoid using the system cargo.
2024-09-09 14:24:38 +08:00
5a34551905 allow the use of the LLVM shared library
Which in turns allows working around the incompatibility of the LLVM static library
with Rust link-args=-rdynamic, which produces binaries that either fail to link (OpenBSD)
or segfault on startup (Linux).

The year is 2024 and compiler toolchains are still a trash fire like this.
2024-09-09 11:17:31 +08:00
6098b1b853 fix previous commit 2024-09-06 11:32:08 +08:00
668ccb1c95 nac3core: expose inkwell and nac3parser 2024-09-06 11:06:26 +08:00
a3c624d69d update all dependencies 2024-09-06 10:21:58 +08:00
bd06155f34 irrt: compatibility with pre-C23 compilers 2024-09-05 18:54:55 +08:00
9c33c4209c [core] Fix type of ndarray.element_type
Should be the element type of the NDArray itself, not the pointer to its
type.
2024-08-30 22:47:38 +08:00
122983f11c flake: update dependencies 2024-08-30 14:45:38 +08:00
71c3a65a31 [core] codegen/stmt: Fix obtaining return type of sret functions 2024-08-29 19:15:30 +08:00
8c540d1033 [core] codegen/stmt: Add more casts for boolean types 2024-08-29 16:36:32 +08:00
0cc60a3d33 [core] codegen/expr: Fix missing cast to i1 2024-08-29 16:36:32 +08:00
a59c26aa99 [artiq] Fix RPC of ndarrays from host 2024-08-29 16:08:45 +08:00
02d93b11d1 [meta] Update dependencies 2024-08-29 14:32:21 +08:00
59cad5bfe1
standalone: clang-format demo.c 2024-08-29 10:37:24 +08:00
4318f8de84
standalone: improve src/assignment.py 2024-08-29 10:33:58 +08:00
15ac00708a [core] Use quoted include paths instead of angled brackets
This is preferred for user-defined headers.
2024-08-28 16:37:03 +08:00
c8dfdcfdea
standalone & artiq: remove class_names from resolver 2024-08-27 23:43:40 +08:00
600a5c8679 Revert "standalone: reformat demo.c"
This reverts commit 308edb8237.
2024-08-27 23:06:49 +08:00
22c4d25802 core/typecheck: add missing typecheck in matmul 2024-08-27 22:59:39 +08:00
308edb8237 standalone: reformat demo.c 2024-08-27 22:55:22 +08:00
9848795dcc core/irrt: add exceptions and debug utils 2024-08-27 22:55:22 +08:00
58222feed4 core/irrt: split into headers 2024-08-27 22:55:22 +08:00
518f21d174 core/irrt: build.rs capture IR defined constants 2024-08-27 22:55:22 +08:00
e8e49684bf core/irrt: build.rs capture IR defined types 2024-08-27 22:55:22 +08:00
b2900b4883 core/irrt: use +std=c++20 to compile
To explicitly set the C++ variant and avoid inconsistencies.
2024-08-27 22:55:22 +08:00
c6dade1394 core/irrt: reformat 2024-08-27 22:55:22 +08:00
7e3fcc0845 add .clang-format 2024-08-27 22:55:22 +08:00
d3b4c60d7f core/irrt: comment build.rs & move irrt to nac3core/irrt 2024-08-27 22:55:22 +08:00
5b2b6db7ed core: improve error messages 2024-08-26 18:37:55 +08:00
15e62f467e standalone: add tests for polymorphism 2024-08-26 18:37:55 +08:00
2c88924ff7 core: add support for simple polymorphism 2024-08-26 18:37:55 +08:00
a744b139ba core: allow Call and AnnAssign in init block 2024-08-26 18:37:55 +08:00
2b2b2dbf8f [core] Fix resolution of exception names in raise short form
Previous implementation fails as `resolver.get_identifier_def` in ARTIQ
would return the exception __init__ function rather than the class.

We fix this by limiting the exception class resolution to only include
raise statements, and to force the exception name to always be treated
as a class.

Fixes #501.
2024-08-26 18:35:02 +08:00
d9f96dab33 [core] Add codegen_unreachable 2024-08-23 13:10:55 +08:00
c5ae0e7c36 [standalone] Add tests for tuple equality 2024-08-21 16:25:32 +08:00
b8dab6cf7c [standalone] Add tests for string equality 2024-08-21 16:25:32 +08:00
4d80ba38b7 [core] codegen/expr: Implement comparison of tuples 2024-08-21 16:25:32 +08:00
33929bda24 [core] typecheck/typedef: Add support for tuple methods 2024-08-21 16:25:32 +08:00
a8e92212c0 [core] codegen/expr: Implement string equality 2024-08-21 16:25:32 +08:00
908271014a [core] typecheck/magic_methods: Add equality methods to string 2024-08-21 16:25:32 +08:00
c407622f5c [core] codegen/expr: Add compilation error for unsupported cmpop 2024-08-21 15:46:13 +08:00
d7952d0629 [core] codegen/expr: Fix assertions not generated for -O0 2024-08-21 15:36:54 +08:00
ca1395aed6 [core] codegen: Remove redundant return 2024-08-21 15:36:54 +08:00
7799aa4987 [meta] Do not specify rev in dependency version 2024-08-21 15:36:54 +08:00
76016a26ad [meta] Apply clippy suggestions 2024-08-21 13:07:57 +08:00
8532bf5206
standalone: add missing test_ndarray_ceil() run 2024-08-21 11:39:00 +08:00
2cf64d8608
apply clippy comment changes 2024-08-21 11:21:10 +08:00
706759adb2
artiq: apply cargo fmt 2024-08-21 11:21:10 +08:00
b90cf2300b
core/fix: add missing lifetime in gen_for* 2024-08-21 11:05:30 +08:00
0fc26df29e flake: update nixpkgs 2024-08-19 23:53:15 +08:00
0b074c2cf2 [artiq] symbol_resolver: Set private linkage for constants 2024-08-19 14:41:43 +08:00
a0f6961e0e cargo: update dependencies 2024-08-19 13:15:03 +08:00
b1c5c2e1d4 [artiq] Fix RPC of ndarrays to host 2024-08-15 15:41:24 +08:00
69320a6cf1 [artiq] Fix LLVM representation of strings
Should be `%str` rather than `[N x i8]`.
2024-08-14 09:30:08 +08:00
9e0601837a core: Add compile-time feature to disable escape analysis 2024-08-14 09:29:48 +08:00
432c81a500
core: update insta after #489 2024-08-13 15:30:34 +08:00
6beff7a268 [artiq] Implement core_log and rtio_log in terms of polymorphic_print
Implementation mostly references the original implementation in Python.
2024-08-13 15:19:03 +08:00
6ca7aecd4a [artiq] Add core_log and rtio_log function declarations 2024-08-13 15:19:03 +08:00
8fd7216243 [core] toplevel/composer: Add lateinit_builtins
This is required for the new core_log and rtio_log functions, which take
a generic type as its parameter. However, in ARTIQ builtins are
initialized using one unifier and then actually used by another unifier.

lateinit_builtins workaround this issue by deferring the initialization
of functions requiring type variables until the actual unifier is ready.
2024-08-13 15:19:03 +08:00
4f5e417012 [core] codegen: Add function to get format constants for integers 2024-08-13 15:19:03 +08:00
a0614bad83 [core] codegen/expr: Make gen_string return StructValue
So that it is clear that the value itself is returned rather than a
pointer to the struct or its data.
2024-08-13 15:19:03 +08:00
5539d144ed [core] Add CodeGenContext::build_in_bounds_gep_and_load
For safer accesses to `gep`-able values and faster fails.
2024-08-13 15:19:03 +08:00
b3891b9a0d standalone: Fix several issues post script refactoring
- Add helptext for check_demos.sh
- Add back support for using debug NAC3 for running tests
- Output error message when argument is not recognized
- Fixed last non-demo script argument being ignored
- Add back SSE2 requirement to NAC3 (required for mandelbrot)
2024-08-13 15:19:03 +08:00
6fb8939179 [meta] Update dependencies 2024-08-13 15:19:03 +08:00
973dc5041a core/typecheck: Support tuple arg type in len() 2024-08-13 15:02:59 +08:00
d0da688aa7 standalone: Add tuple len test 2024-08-13 15:02:59 +08:00
12c4e1cf48 core/toplevel/builtins: Add support for len() on tuples 2024-08-13 15:02:59 +08:00
9b988647ed core/toplevel/builtins: Extract len() into builtin function 2024-08-13 15:02:59 +08:00
35a7cecc12
core/typecheck: fix np_array ndmin bug 2024-08-13 12:50:04 +08:00
7e3d87f841 core/codegen: fix bug in call_ceil function 2024-08-07 16:40:55 +08:00
ac0d83ef98 standalone: Add vararg.py 2024-08-06 11:48:42 +08:00
3ff6db1a29 core/codegen: Add va_start and va_end intrinsics 2024-08-06 11:48:42 +08:00
d7b806afb4 core/codegen: Implement support for va_info on supported architectures 2024-08-06 11:48:40 +08:00
fac60c3974 core/codegen: Handle vararg in function generation 2024-08-06 11:46:00 +08:00
f5fb504a15 core/codegen/expr: Implement vararg handling in gen_call 2024-08-06 11:46:00 +08:00
faa3bb97ad core/typecheck/typedef: Add vararg to Unifier::stringify 2024-08-06 11:46:00 +08:00
6a64c9d1de core/typecheck/typedef: Add is_vararg_ctx to TTuple 2024-08-06 11:45:54 +08:00
3dc8498202 core/typecheck/typedef: Handle vararg parameters in unify_call 2024-08-06 11:43:13 +08:00
cbf79c5e9c core/typecheck/typedef: Add is_vararg to FuncArg, ConcreteFuncArg 2024-08-06 11:43:13 +08:00
b8aa17bf8c core/toplevel/composer: Add parsing for vararg parameter 2024-08-06 10:52:24 +08:00
f5b998cd9c core/codegen: Remove unnecessary mut from get_llvm*_type 2024-08-06 10:52:24 +08:00
c36f85ecb9 meta: Update dependencies 2024-08-06 10:52:24 +08:00
3a8c385e01 core/typecheck: fix missing ExprKind::Asterisk in fix_assignment_target_context 2024-08-05 19:30:48 +08:00
221de4d06a core/codegen: add missing comment 2024-08-05 19:30:48 +08:00
fb9fe8edf2 core: reimplement assignment type inference and codegen
- distinguish between setitem and getitem
- allow starred assignment targets, but the assigned value would be a tuple
- allow both [...] and (...) to be target lists
2024-08-05 19:30:48 +08:00
894083c6a3 core/codegen: refactor gen_{for,comprehension} to match on iter type 2024-08-05 19:30:48 +08:00
669c6aca6b clean up and fix 32-bit demos 2024-08-05 19:04:25 +08:00
63d2b49b09 core: remove np_linalg_matmul 2024-08-05 11:44:55 +08:00
bf709889c4 standalone/demo: separate linalg functions from main workspace 2024-08-05 11:44:54 +08:00
1c72698d02 core: add np_linalg_det and np_linalg_matrix_power functions 2024-07-31 18:02:54 +08:00
54f883f0a5 core: implement np_dot using LLVM_IR 2024-07-31 15:53:51 +08:00
4a6845dac6 standalone: add np.transpose and np.reshape functions 2024-07-31 13:23:07 +08:00
00236f48bc core: add np.transpose and np.reshape functions 2024-07-31 13:23:07 +08:00
a3e6bb2292 core/helper: add linalg section 2024-07-31 13:23:07 +08:00
17171065b1 standalone: link linalg at runtime 2024-07-31 13:23:07 +08:00
540b35ec84 standalone: move linalg functions to demo 2024-07-31 13:23:05 +08:00
4bb00c52e3 core/builtin_fns: improve error reporting 2024-07-31 13:21:31 +08:00
faf07527cb standalone: add runtime implementation for linalg functions 2024-07-31 13:21:28 +08:00
d6a4d0a634 standalone: add linalg methods and tests 2024-07-29 16:48:06 +08:00
2242c5af43 core: add linalg methods 2024-07-29 16:48:06 +08:00
318a675ea6 standalone: Rename -m32 to -i386 2024-07-29 14:58:58 +08:00
32e52ce198 standalone: Revert using uint32_t as slice length
Turns out list and str have always been size_t.
2024-07-29 14:58:29 +08:00
665ca8e32d cargo: update dependencies 2024-07-27 22:24:56 +08:00
12c12b1d80 flake: update nixpkgs 2024-07-27 22:22:20 +08:00
72972fa909 core/toplevel: add more numpy categories 2024-07-27 21:57:47 +08:00
142cd48594 core/toplevel: reorder PrimDef::details 2024-07-27 21:57:47 +08:00
8adfe781c5 core/toplevel: fix PrimDef method names 2024-07-27 21:57:47 +08:00
339b74161b core/toplevel: reorganize PrimDef 2024-07-27 21:57:47 +08:00
8c5ba37d09 standalone: Add 32-bit execution tests to check_demo.sh 2024-07-26 13:35:40 +08:00
05a8948ff2 core: Minor cleanup to use ListValue APIs 2024-07-26 13:35:40 +08:00
6d171ec284 core: Add label name and hooks to gen_for functions 2024-07-26 13:35:40 +08:00
0ba68f6657 core: Set target triple and datalayout for each module
Fixes an issue with inconsistent pointer sizes causing crashes.
2024-07-26 13:35:40 +08:00
693b2a8863 core: Add support for 32-bit size_t on 64-bit targets 2024-07-26 13:35:40 +08:00
5faeede0e5 Determine size_t using LLVM target machine 2024-07-26 13:35:38 +08:00
266707df9d standalone: Add support for running 32-bit binaries 2024-07-26 13:32:38 +08:00
3d3c258756 standalone: Remove support for --lli 2024-07-26 13:32:38 +08:00
ed1182cb24 standalone: Update format specifiers for exceptions
Use platform-agnostic identifiers instead.
2024-07-26 13:32:37 +08:00
fd025c1137 standalone: Use uint32_t for cslice length
Matching the expected type of string and list slices.
2024-07-26 13:32:21 +08:00
f139db9af9 meta: Update dependencies 2024-07-26 10:33:02 +08:00
44487b76ae standalone: interpret_demo.py remove duplicated section 2024-07-22 17:23:35 +08:00
1332f113e8 standalone: fix interpret_demo.py comments 2024-07-22 17:06:14 +08:00
7632d6f72a cargo: update dependencies 2024-07-21 11:00:25 +08:00
4948395ca2 core/toplevel/type_annotation: Add handling for mismatching class def
Primitive types only contain fields in its Type and not its TopLevelDef.
This causes primitive object types to lack some fields.
2024-07-19 14:42:14 +08:00
3db3061d99 artiq/symbol_resolver: Handle type of zero-length lists 2024-07-19 14:42:14 +08:00
51c2175c80 core/codegen/stmt: Convert assertion values to i1 2024-07-19 14:42:14 +08:00
1a31a50b8a
standalone: fix __nac3_raise def in demo.c 2024-07-17 21:22:08 +08:00
6c10e3d056 core: cargo clippy 2024-07-12 21:18:53 +08:00
2dbc1ec659 cargo fmt 2024-07-12 21:16:38 +08:00
c80378063a add np_argmin/argmax to interpret_demo environment 2024-07-12 13:27:52 +02:00
513d30152b core: support raise exception short form 2024-07-12 18:58:34 +08:00
45e9360c4d standalone: Add np_argmax and np_argmin tests 2024-07-12 18:19:56 +08:00
2e01b77fc8 core: refactor np_max/np_min functions 2024-07-12 18:18:54 +08:00
cea7cade51 core: add np_argmax/np_argmin functions 2024-07-12 18:18:28 +08:00
d658d9b00e update dependencies, Python 3.12 on Linux 2024-07-09 23:56:12 +08:00
eeb474f9e6 core: reduce code duplication in codegen/extern_fns (#453)
Used macros to reduce code duplication in `codegen/extern_fns`

Reviewed-on: M-Labs/nac3#453
Co-authored-by: abdul124 <ar@m-labs.hk>
Co-committed-by: abdul124 <ar@m-labs.hk>
2024-07-09 16:31:08 +08:00
88b72af2d1 core/llvm_intrinsic: improve macro name and comments 2024-07-09 16:30:32 +08:00
b73f6c4d68 core: reduce code duplication in codegen/llvm_intrinsic 2024-07-09 16:30:32 +08:00
f47cdec650 standalone: Fix output format of output_range 2024-07-09 13:55:48 +08:00
d656880e44 standalone: Fix missing implementation for output_range 2024-07-09 13:53:50 +08:00
a91602915a core: Fix missing fields in range type 2024-07-09 13:53:50 +08:00
1c56005a01 core: Reformat and modernize irrt.cpp
- Use anon namespace instead of static
- Use using declaration instead of typedef
- Align pointers to the type instead of the identifier
2024-07-09 13:53:50 +08:00
bc40a32524 core: Add report_type_error to enable more code reuse 2024-07-09 13:44:47 +08:00
c820daf5f8 core: Apply cargo format 2024-07-09 13:32:10 +08:00
25d2de67f7 standalone: Add output_range and tests 2024-07-09 04:44:40 +08:00
2cfb7a7e10 core: Refactor range function into constructor 2024-07-09 04:44:40 +08:00
9238a5e86e standalone: Rename output_str to output_strln and add output_str
output_str is for outputting strings without newline, and the newly
introduced output_strln now has the old behavior of ending with a
newline.
2024-07-09 04:44:40 +08:00
76defac462 meta: use clang -x c++ instead of clang++ 2024-07-07 20:03:34 +08:00
650f354b74 core: use C++ for irrt source 2024-07-07 14:36:10 +08:00
f062ef5f59 core/llvm_intrinsic: replace roundeven with rint 2024-07-07 14:24:18 +08:00
f52086b706 core: improve binop and cmpop error messages 2024-07-05 16:27:24 +08:00
0a732691c9 core: refactor typecheck/magic_methods.rs operators & add op symbol name 2024-07-05 16:27:20 +08:00
cbff356d50 core: workaround inkwell on llvm.stackrestore 2024-07-05 13:56:12 +08:00
24ac3820b2 core: check int32 obj_id directly in fold_numpy_function_call_shape_argument 2024-07-05 10:36:47 +08:00
ba32fab374 standalone: Add demos for list arithmetic operators 2024-07-04 16:01:15 +08:00
c4052b6342 core: Implement multi-operand __eq__ and __ne__ for lists 2024-07-04 16:01:15 +08:00
66c205275f core: Implement list::__add__ 2024-07-04 16:01:11 +08:00
c85e412206 core: Implement list::__mul__ 2024-07-04 15:53:50 +08:00
075536d7bd core: Add BreakContinueHooks for gen_for_callback 2024-07-04 15:32:18 +08:00
13beeaa2bf core: Implement handling for zero-length lists 2024-07-04 15:32:18 +08:00
2194dbddd5 core/type_annotation: Refactor List type to TObj
In preparation for operators on lists.
2024-07-04 15:32:18 +08:00
94a1d547d6 meta: Update dependencies 2024-07-04 15:32:18 +08:00
d6565feed3 core: ndarray_from_ndlist_impl cast size_of to usize 2024-07-04 12:24:52 +08:00
83154ef8e1 core/llvm_intrinsics: remove llvm.roundeven call from call_float_roundeven 2024-07-03 14:17:47 +08:00
0744b938b8 core: fix __nac3_ndarray_calc_size crash due to incorrect typing 2024-07-03 13:03:14 +08:00
56fa2b6803 core: fix crash on iterating over non-iterables
a
2024-06-28 15:45:53 +08:00
d06c13f936 core: fix crash on invalid subscripting 2024-06-27 16:58:48 +08:00
9808923258 core: improve comments in type_inferencer/mod.rs 2024-06-27 14:46:48 +08:00
5b11a1dbdd core: support tuple and int32 input for np_empty, np_ones, and more 2024-06-27 14:30:17 +08:00
b21df53e0d core: fix comment typo in unify_call() 2024-06-27 14:06:39 +08:00
0ec967a468 core: improve function call errors 2024-06-27 14:06:39 +08:00
ca8459dc7b standalone: prettify TopLevelComposer error reporting 2024-06-27 10:15:14 +08:00
b0b804051a nac3artiq: allow class attribute access without init function 2024-06-25 16:06:33 +08:00
134af79fd6 core: add support for class attributes 2024-06-25 16:06:33 +08:00
7fe2c3496c core: add attribute field to class definition 2024-06-25 16:06:33 +08:00
144a3fc426 core: more derive Debug in typedef 2024-06-25 15:02:50 +08:00
74096eb9f6 core: name codegen worker threads 2024-06-25 12:36:37 +08:00
06e9d90d57 apply clippy changes 2024-06-21 14:14:01 +08:00
d89146aa02 core: use no_run on builtin_fns docs 2024-06-20 13:53:25 +08:00
5bade81ddb standalone: Add test for multidim array index with one index 2024-06-20 12:50:30 +08:00
0452e6de78 core: Fix codegen for tuple-index into ndarray 2024-06-20 12:50:30 +08:00
635c944c90 core: Fix type inference for tuple-index into ndarray
Fixes #420.
2024-06-20 12:50:30 +08:00
e36af3b0a3 core: reduce code duplication in codegen/builtin_fns (#422)
Used macros to generate some unary math functions.

Reviewed-on: M-Labs/nac3#422
Reviewed-by: David Mak <chmakac@connect.ust.hk>
Co-authored-by: lyken <lyken@m-labs.hk>
Co-committed-by: lyken <lyken@m-labs.hk>
2024-06-20 12:48:44 +08:00
5b1aa812ed update dependencies 2024-06-20 10:43:55 +08:00
d3cd2a8d99 artiq: Add support for generating RPC tag for ndarray 2024-06-19 18:56:16 +08:00
202a63274d artiq: Implement pyty-to-ty conversion 2024-06-19 18:56:15 +08:00
76dd5191f5 artiq: Implement Python-to-LLVM conversion of ndarray 2024-06-19 18:56:15 +08:00
8d9df0a615 artiq: Fix ndarray class ID
We want the class ID of the ndarray class, not its corresponding typing
class.
2024-06-19 18:56:15 +08:00
07adfb2a18 standalone: Add *.ll to Gitignore list 2024-06-19 18:56:15 +08:00
f00e458f60 add test for class without __init__ 2024-06-19 18:16:54 +08:00
1bc95a7ba6 Add handling for np.bool_ and np.str_ 2024-06-19 15:10:47 +08:00
e85f4f9bd2 core: refactor top_level::builtins::get_builtins() 2024-06-18 11:06:25 +08:00
ce3e9bf4fe nac3artiq: add support string attributes in classes 2024-06-17 16:53:51 +08:00
82091b1be8 meta: Apply clippy changes 2024-06-17 14:10:31 +08:00
32919949e2 Run clippy --tests on pre-commit hook 2024-06-17 12:51:25 +08:00
2abe75d1f4 core: remove code dup with make_exception_fields 2024-06-17 12:01:48 +08:00
676412fe6d apply cargo fmt 2024-06-14 09:46:42 +08:00
8b9df7252f core: cleanup with Unifier::generate_var_id 2024-06-14 09:42:04 +08:00
6979843431 core: fix typo in into_var_map 2024-06-13 16:59:10 +08:00
fed1361c6a core: rename to_var_map to into_var_map 2024-06-13 16:59:10 +08:00
aa94e0c8a4 core: remove pub & add From<TypeVarId> for u32 2024-06-13 16:59:10 +08:00
f523e26227 core: fix typo in fmt::Display of TypeVarId 2024-06-13 16:59:10 +08:00
f026b48e2a core: refactor to use TypeVarId and TypeVar 2024-06-13 16:59:10 +08:00
dc874f2994 core: use PrimDef simple names in make_primitives() 2024-06-13 16:58:32 +08:00
95de0800b4 core/demo: fix typo in .gitignore 2024-06-13 16:05:33 +08:00
3d71c6a850 core/demo: gitignore to ignore *.bc & *.o 2024-06-13 16:00:23 +08:00
be55e2ac80 meta: Update README to include info regarding pre-commit hooks 2024-06-12 16:10:57 +08:00
79c8b759ad meta: Add pre-commit configuration 2024-06-12 16:10:57 +08:00
4798c53a21 flake: Add pre-commit to dev environment 2024-06-12 16:10:57 +08:00
23974feae7 meta: Restrict number of allowed lints 2024-06-12 16:10:57 +08:00
40a3bded36 meta: Set clippy lints in {main,lib}.rs
So that this does not have to be manually passed to the `cargo clippy`
command-line every single time. Also allows incrementally addressing
these lints by removing and fixing them one-by-one.
2024-06-12 16:10:57 +08:00
c4420e6ab9 core: refactor get_builtins() 2024-06-12 15:09:20 +08:00
fd36f78005 core: refactor PrimitiveDefinitionId into enum PrimDef 2024-06-12 15:01:01 +08:00
8168692cc3 apply cargo fmt 2024-06-12 14:45:03 +08:00
53d44b9595 standalone: Add np_array tests 2024-06-11 16:44:36 +08:00
6153f94b05 core/numpy: Implement codegen for np_array 2024-06-11 16:42:11 +08:00
4730b595f3 core/builtins: Add np_array function 2024-06-11 16:42:08 +08:00
c2fdb12397 core/type_inferencer: Add special rule for np_array 2024-06-11 16:40:35 +08:00
82bf14785b core: Add multidimensional array helpers 2024-06-11 15:30:06 +08:00
2d4329e23c core/stmt: Use BB of last statement in if-else in phi 2024-06-11 15:30:06 +08:00
679656f9e1 core/classes: Fix incorrect field locations for lists 2024-06-11 15:30:06 +08:00
210d9e2334 core: Add more creator functions for ProxyType 2024-06-11 15:26:37 +08:00
181ac3ec1a core/classes: Fix incorrect pointers of range.{stop,step} 2024-06-11 15:13:31 +08:00
3acdfb304d meta: Apply clippy suggestions 2024-06-11 14:58:32 +08:00
6e24da9cc5 meta: Update dependencies 2024-06-11 14:58:32 +08:00
f0ab1b858a core: Refactor class abstractions
- Introduce new Type abstractions
- Rearrange some functions
2024-06-06 13:45:51 +08:00
08129cc635 nac3core: add TopLevelComposer::new builtin check's assertion msg 2024-06-05 15:30:02 +08:00
ad4832dcf4 core: Refactor to get LLVM intrinsics via Intrinsics::find 2024-06-05 15:29:40 +08:00
520bbb246b flake: add llvmPackages_14.llvm to devShells linux default (#405)
Co-authored-by: lyken <lyken@m-labs.hk>
Co-committed-by: lyken <lyken@m-labs.hk>
2024-06-05 11:11:56 +08:00
b857f1e403 nac3core: fix typo in gen_for's comment 2024-06-04 17:15:41 +08:00
fa8af37e84 flake: update nixpkgs 2024-06-03 22:22:04 +08:00
23b2fee4e7 standalone: Add test case for ndarray slicing 2024-06-03 16:40:05 +08:00
ed79d5bb9e core/expr: Add support for multi-dim slicing of NDArrays 2024-06-03 16:40:05 +08:00
c35ad06949 core/expr: Add support for 1D slicing of NDArrays 2024-06-03 16:40:05 +08:00
135ef557f9 core/numpy: Implement ndarray_sliced_{copy,copyto_impl}
Performing copying with optional support for slicing. Also made
copy_impl delegate to sliced_copy, as sliced_copy now performs a
superset of operations that copy_impl can already do.
2024-06-03 16:40:05 +08:00
a176c3eb70 core/irrt: Change handle_slice_indices to instead take length of object
So that all other array-like datatypes (e.g. ndarray) can also take
advantage of it.
2024-06-03 16:40:05 +08:00
2cf79510c2 core/numpy: Add more helper functions 2024-06-03 16:40:05 +08:00
b6ff75dcaf core/irrt: Add support for calculating partial size of NDArray 2024-06-03 16:40:05 +08:00
588c15f80d core/stmt: Add gen_for_range_callback
For generating for loops over range objects or array slices.
2024-06-03 16:40:05 +08:00
82cc693b11 meta: Update dependencies 2024-06-03 16:40:02 +08:00
520e1adc56 core/builtins: Add np_minimum/np_maximum 2024-05-09 15:01:20 +08:00
73e81259f3 core/builtins: Add np_min/np_max 2024-05-09 15:01:20 +08:00
7627acea41 core/type_inferencer: Fix error message 2024-05-09 15:01:20 +08:00
a777099ea8 core/type_inferencer: Fix missing lowering for some builtin TVars 2024-05-09 15:01:20 +08:00
876e6ea7b8 meta: Update dependencies 2024-05-08 17:27:38 +08:00
30c6cffbad core/builtins: Refactored numpy builtins to accept scalar and ndarrays 2024-05-06 15:38:29 +08:00
51671800b6 core/builtins: Extract codegen portion into functions
We will need to reuse them when implementing elementwise function
application for ndarrays.
2024-05-06 13:21:54 +08:00
7195476edb core/builtins: Add llvm_intrinsics prefix 2024-05-06 13:21:54 +08:00
eecba0b71d core: Add GenCall::create_dummy
A simple abstraction for GenCalls that are already handled elsewhere.
2024-05-06 13:21:54 +08:00
7b4253ccd8 core/numpy: Add missing lifetime parameters 2024-05-06 13:21:54 +08:00
f58c3a11f8 core/builtins: Rework handling of PrimitiveStore-Unifier tuples 2024-05-06 13:21:54 +08:00
d0766a116f core: Remove Box from GenCallCallback type alias
So that references to the function type can be taken.
2024-05-06 13:21:54 +08:00
64a3751fc2 core: Remove custom function type definitions for ndarray operators 2024-05-06 13:21:54 +08:00
9566047241 standalone: Fix cbrt never tested 2024-05-06 13:21:54 +08:00
062e318dd5 core/magic_methods: Fix clippy warnings 2024-05-06 13:21:54 +08:00
c4dc36ae99 standalone: Add explicit -- for delimiting run args vs NAC3 args 2024-05-06 13:21:54 +08:00
baac348ee6 meta: Update dependencies 2024-05-06 13:21:37 +08:00
847615fc2f core: Implement numpy.matmul for 2D-2D ndarrays 2024-04-23 10:27:37 +08:00
5dfcc63978 core/classes: Take reference of indexes 2024-04-16 17:20:24 +08:00
025b3cd02f core/stmt: Remove gen_if_chained*
Turns out it is really difficult to get lifetimes and closures right, so
let's just provide the most rudimentary if-else codegen and we can nest
them if necessary.
2024-04-16 17:16:50 +08:00
e0f440040c core/expr: Implement negative indices for ndarray 2024-04-15 12:49:42 +08:00
f0715e2b6d core/stmt: Add gen_if* functions
For generating if-constructs in IR.
2024-04-15 12:20:34 +08:00
e7fca67786 core/stmt: Do not generate jumps if bb is already terminated
Future-proofs gen_*_callback functions in case other codegen functions
will delegate to it in the future.
2024-04-15 12:20:34 +08:00
52c731c312 core: Implement Not/UAdd/USub for booleans
Not sure if this is deliberate or an oversight, but we implement it
anyway for consistency with other Python implementations.
2024-04-12 18:29:58 +08:00
00d1b9be9b core: Fix __inv__ for i8-based boolean operands 2024-04-12 15:35:54 +08:00
8404d4c4dc meta: Update dependencies 2024-04-12 15:29:09 +08:00
e614dd4257 core/type_inferencer: Fix location of unary/compare expressions
Codegen uses this location information to determine the CallId, and if
a function call is the operand of a unary expression or left-hand
operand of a compare expression, codegen will use the type of the
operator expression rather than the actual operand type.
2024-04-05 15:42:10 +08:00
937a8b9698 core/magic_methods: Fix type of unary ops with primitive types 2024-04-05 13:23:08 +08:00
876ad6c59c core/type_inferencer: Include location info if inferencer fails 2024-04-05 13:22:35 +08:00
a920fe0501 core: Implement elementwise comparison operators 2024-04-03 00:07:33 +08:00
727a1886b3 core: Implement elementwise unary operators 2024-04-03 00:07:33 +08:00
6af13a8261 core: Implement elementwise binary operators
Including immediate variants of these operators.
2024-04-03 00:07:33 +08:00
3540d0ab29 core/magic_methods: Add typeof_*op
Used to determine the expected type of the binary operator with
primitive operands.
2024-04-03 00:07:33 +08:00
3a6c53d760 core/toplevel/numpy: Split ndarray type var utilities 2024-04-03 00:07:33 +08:00
87bc34f7ec core: Implement calculations for broadcasting ndarrays 2024-04-03 00:07:31 +08:00
f50a5f0345 core/type_inferencer: Allow both int32 and isize when indexing ndarray 2024-04-02 16:49:12 +08:00
a77fd213e0 core/magic_methods: Allow unknown return types
These types can be later inferred by the type inferencer.
2024-04-02 16:49:12 +08:00
8f1497df83 core/helper: Add PrimitiveDefinitionIds::iter 2024-04-02 16:49:12 +08:00
5ca2dbeec8 core/typedef: Add Type::obj_id to replace get_obj_id 2024-04-02 16:49:10 +08:00
9a98cde595 core: Extract codegen portion of gen_*op_expr
This allows *ops to be generated internally using LLVM values as
input. Required in a future change.
2024-04-01 16:48:25 +08:00
5ba8601b39 core: Remove ArrayValue variants of functions
These will be lowered and optimized away later anyways, and we have
ArrayLikeAccessor now.
2024-04-01 16:48:25 +08:00
26a01b14d5 core: Use more typed slices in APIs 2024-04-01 16:48:25 +08:00
d5f4817134 core/builtins: Fix len() on ndarrays 2024-04-01 16:48:24 +08:00
789bfb5a26 core: Fix index-based operations not returning i32 2024-04-01 16:46:45 +08:00
4bb0e60981 core: Apply clippy suggestions 2024-04-01 16:46:41 +08:00
623fcf85af msys2: update 2024-03-25 14:45:36 +08:00
13f06f3e29 core: Refactor VarMap to IndexMap
This is the only Map I can find that preserves insertion order while
also deduplicating elements by key.
2024-03-22 15:51:23 +08:00
f0da9c0283 core: Add ArrayLikeValue
For exposing LLVM values that can be accessed like an array.
2024-03-22 15:51:06 +08:00
2c4bf3ce59 core: Allow unsized CodeGenerator to be passed to some codegen functions
Enables codegen_callback to call these codegen functions as well.
2024-03-22 15:07:28 +08:00
e980f19c93 core: Simplify typed value assertions 2024-03-22 15:07:28 +08:00
cfbc37c1ed core: Add gen_for_callback_incrementing
Simplifies generation of monotonically increasing for loops.
2024-03-22 15:07:28 +08:00
50264e8750 core: Add missing unchecked accessors for NDArrayDimsProxy 2024-03-22 15:07:28 +08:00
1b77e62901 core: Split numpy into codegen and toplevel 2024-03-22 15:07:28 +08:00
fd44ee6887 core: Apply clippy suggestions 2024-03-22 15:07:23 +08:00
c8866b1534 core/classes: Rename get_* functions to remove prefix
As suggested by Rust API Guidelines.
2024-03-21 15:46:10 +08:00
84a888758a core: Rename unsafe functions to unchecked
This is this intended name of the functions.
2024-03-21 15:46:10 +08:00
9d550725b7 meta: Update cargo dependencies 2024-03-21 15:45:26 +08:00
2edc1de0b6 standalone: Update ndarray.py to output all elements in ndarrays 2024-03-07 14:59:13 +08:00
c3b122acfc core: Implement ndarray.copy 2024-03-07 14:59:13 +08:00
a94927a11d core: Update __builtin_assume expressions
No dimension size should be 0.
2024-03-07 14:59:13 +08:00
ebf86cd134 core: Use size_t for accessing array elements 2024-03-07 14:59:13 +08:00
cccd8f2d00 core: Fix ndarray_eye not preserving signness of offset 2024-03-07 14:59:13 +08:00
3292aed099 core: Fix ndarray subscript operator returning the wrong object
Should be returning the newly created object instead of the original
ndarray...
2024-03-07 14:59:13 +08:00
96b7f29679 core: Implement ndarray.fill 2024-03-07 14:59:13 +08:00
3d2abf73c8 core: Replace ndarray_init_dims IRRT impl with IR impl
Implementation of that function in IR allows for more flexibility in
terms of different integer type widths.
2024-03-07 14:59:13 +08:00
f682e9bf7a core: Match IRRT compile flavor with build profile 2024-03-07 14:59:02 +08:00
b26cb2b360 core: Express member func def IDs as offsets from class def ID 2024-03-06 12:24:39 +08:00
2317516cf6 core: Use tvars from ndarray for class definition 2024-03-04 23:58:02 +08:00
77de24ef74 core: Use BTreeMap for type variable mapping
There have been multiple instances where I had the need to iterate over
type variables, only to discover that the traversal order is arbitrary.

This commit fixes that by adding SortedMapping, which utilizes BTreeMap
internally to guarantee a traversal order. All instances of VarMap are
now refactored to use this to ensure that type variables are iterated in
 the order of its variable ID, which should be monotonically incremented
 by the unifier.
2024-03-04 23:56:04 +08:00
234a6bde2a core: Use TObj for NDArray 2024-03-01 15:41:55 +08:00
c3db6297d9 core: Add primitive definition-id list
So that we have a single ground truth for the definition IDs of
primitive types.
2024-03-01 11:29:10 +08:00
82fdb02d13 core: Extract LLVM intrinsic functions to their functions 2024-02-23 15:41:06 +08:00
4efdd17513 core: Add missing From implementations for LLVM wrapper classes 2024-02-23 15:41:06 +08:00
49de81ef1e core: Apply clippy suggestions 2024-02-23 15:41:06 +08:00
8492503af2 core: Update cargo dependencies 2024-02-23 15:41:04 +08:00
e1dbe2526a flake: switch to nixpkgs unstable for newer rustc 2024-02-20 15:46:51 +08:00
f37de381ce update dependencies 2024-02-20 13:33:20 +08:00
4452c8986a update ARTIQ version used for PGO profiling 2024-02-20 13:29:00 +08:00
22e831cb76 core: Add test for indexing into ndarray 2024-02-19 17:13:10 +08:00
cc538d221a core: Implement codegen for indexing into ndarray 2024-02-19 17:13:09 +08:00
0d5c53e60c core: Implement type inference for indexing into ndarray 2024-02-19 17:13:09 +08:00
976a9512c1 core: Add const variants to NDArray element getters 2024-02-19 16:56:21 +08:00
1eacaf9afa core: Fix IRRT argument order to ndarray_flatten_index 2024-02-19 16:37:13 +08:00
8c7e44098a core: Fix IRRT implementation of ndarray_flatten_index 2024-02-19 16:37:13 +08:00
282a3e1911 core: Fix typo in error message 2024-02-14 16:26:13 +08:00
5cecb2bb74 core: Fix Literal use in variable type annotation 2024-02-06 18:16:14 +08:00
1963c30744 core: Use Display output for locations 2024-02-06 18:11:51 +08:00
27011f385b core: Add location to non-primitive value return error 2024-02-02 12:49:21 +08:00
d6302b6ec8 core: Allow tuple of primitives to be returned 2024-02-02 12:48:52 +08:00
fef4b2a5ce standalone: Disable tests requiring return of non-primitive values 2024-01-29 12:49:50 +08:00
b3736c3e99 core: Disallow returning of non-primitive values
Non-primitive values are represented by an `alloca`-ed value in the
function body, and when the pointer is returned from the function, the
`alloca`-ed object is deallocated on the stack.

Related to #54.
2024-01-29 12:49:24 +08:00
e328e44c9a update MSYS2 2024-01-26 15:55:45 +08:00
9e4e90f8a0 update dependencies 2024-01-26 15:52:48 +08:00
8470915809 core: Add NDArrayValue and helper functions 2024-01-25 15:51:39 +08:00
148900302e core: Add RangeValue and helper functions 2024-01-25 15:51:39 +08:00
5ee08b585f core: Add ListValue and helper functions 2024-01-25 15:51:39 +08:00
f1581299fc core: Minor changes to IRRT
Add missing documentation, remove redundant lifetime variables, and fix
typos.
2024-01-25 15:50:53 +08:00
af95ba5012 standalone: Add debug flag to run_demo.sh
Allows running demos using the debug build instead of the (default)
release build.
2024-01-25 15:50:53 +08:00
9c9756be33 standalone: Use size_t in demo.c 2024-01-25 15:50:53 +08:00
2a922c7480 artiq: Fix source module of NDArray
Should be `numpy.typing` instead of `numpy`.
2024-01-17 10:40:08 +08:00
e3e2c36ef4 core: Mark TNDArray and TLiteral as unimplemented in tests 2024-01-17 09:58:14 +08:00
4f9a0110c4 meta: Update insta snapshots 2024-01-17 09:49:50 +08:00
12c0eed0a3 core: Fix compilation of tests 2024-01-17 09:49:49 +08:00
c679474f5c standalone: Fix redefinition of ndarray consumer functions 2024-01-17 09:38:13 +08:00
ab3fa05996 demo: use portable format strings 2024-01-10 18:35:35 +08:00
140f8f8a08 core: Implement most ndarray-creation functions 2023-12-22 16:29:55 +08:00
27fcf8926e core: Implement ndarray constructor and numpy.empty 2023-12-22 16:29:54 +08:00
afa7d9b100 core: Implement helper for creation of generic ndarray 2023-12-21 15:39:49 +08:00
c395472094 core: Initial infrastructure for ndarray 2023-12-21 15:39:46 +08:00
03870f222d core: Extract special method handling in type inferencer
To prepare for more special handling with methods.
2023-12-21 15:38:26 +08:00
e435b25756 core: Allow implicit promotions of integral literals
It should not matter, since it is the value of the literal that matters
with respect to the const generic variable.
2023-12-21 15:21:08 +08:00
bd792904f9 core: Add size_t to primitive store
Used for ndims in ndarray.
2023-12-21 15:20:31 +08:00
1c3a823670 core: Do not discard value names for IRRT 2023-12-20 15:16:02 +08:00
f01d833d48 standalone: Add missing parenthesis 2023-12-20 15:15:47 +08:00
9d64e606f4 core: Reject multiple literal bounds
This is currently broken due to how we handle function calls in the
unifier.
2023-12-18 10:04:25 +08:00
6dccb343bb Revert "core: Do not keep unification result for function arguments"
This reverts commit f09f3c27a5.
2023-12-18 10:01:23 +08:00
d47534e2ad interpret_demo: add typing.Literal 2023-12-18 08:50:49 +08:00
8886964776 core: Remove redundant argument in type annotation parsing 2023-12-16 18:40:48 +08:00
f09f3c27a5 core: Do not keep unification result for function arguments
For some reason, when unifying a function call parameter with an
argument, subsequent calls to the same function will only accept the
type of the substituted argument.

This affect snippets like:

```
def make1() -> C[Literal[1]]:
    return ...

def make2() -> C[Literal[2]]:
    return ...

def consume(instance: C[Literal[1, 2]]):
    pass

consume(make1())
consume(make2())
```

The last statement will result in a compiler error, as the parameter of
consume is replaced with C[Literal[1]].

We fix this by getting a snapshot before performing unification, and
restoring the snapshot after unification succeeds.
2023-12-16 18:40:48 +08:00
0bbc9ce6f5 core: Deduplicate values in Literal
Matches the behavior with `typing.Literal`.
2023-12-16 18:40:48 +08:00
457d3b6cd7 core: Refactor generic constants to Literal
Better matches the syntax of `typing.Literal`.
2023-12-16 18:40:48 +08:00
5f692debd8 core: Add PrimitiveStore into Unifier
This will be used during unification between a const generic variable
and a `Literal`.
2023-12-16 18:40:48 +08:00
c7735d935b standalone: Output id of undefined identifier 2023-12-16 18:40:48 +08:00
b47ac1b89b core: Minor formatting cleanup 2023-12-15 17:46:44 +08:00
160 changed files with 33549 additions and 9816 deletions

32
.clang-format Normal file
View File

@ -0,0 +1,32 @@
BasedOnStyle: LLVM
Language: Cpp
Standard: Cpp11
AccessModifierOffset: -1
AlignEscapedNewlines: Left
AlwaysBreakAfterReturnType: None
AlwaysBreakTemplateDeclarations: Yes
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortFunctionsOnASingleLine: Inline
BinPackParameters: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: AfterColon
BreakInheritanceList: AfterColon
ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ContinuationIndentWidth: 4
DerivePointerAlignment: false
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4
MaxEmptyLinesToKeep: 1
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SortUsingDeclarations: true
SpaceAfterTemplateKeyword: false
SpacesBeforeTrailingComments: 2
TabWidth: 4
UseTab: Never

1
.clippy.toml Normal file
View File

@ -0,0 +1 @@
doc-valid-idents = ["CPython", "NumPy", ".."]

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
__pycache__ __pycache__
/target /target
/nac3standalone/demo/linalg/target
nix/windows/msys2 nix/windows/msys2

24
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,24 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_stages: [pre-commit]
repos:
- repo: local
hooks:
- id: nac3-cargo-fmt
name: nac3 cargo format
entry: nix
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo fmt on the codebase.
args: [develop, -c, cargo, fmt, --all]
- id: nac3-cargo-clippy
name: nac3 cargo clippy
entry: nix
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo clippy on the codebase.
args: [develop, -c, cargo, clippy, --tests]

1043
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -4,6 +4,7 @@ members = [
"nac3ast", "nac3ast",
"nac3parser", "nac3parser",
"nac3core", "nac3core",
"nac3core/nac3core_derive",
"nac3standalone", "nac3standalone",
"nac3artiq", "nac3artiq",
"runkernel", "runkernel",

View File

@ -51,3 +51,12 @@ Use ``nix develop`` in this repository to enter a development shell.
If you are using a different shell than bash you can use e.g. ``nix develop --command fish``. If you are using a different shell than bash you can use e.g. ``nix develop --command fish``.
Build NAC3 with ``cargo build --release``. See the demonstrations in ``nac3artiq`` and ``nac3standalone``. Build NAC3 with ``cargo build --release``. See the demonstrations in ``nac3artiq`` and ``nac3standalone``.
### Pre-Commit Hooks
You are strongly recommended to use the provided pre-commit hooks to automatically reformat files and check for non-optimal Rust practices using Clippy. Run `pre-commit install` to install the hook and `pre-commit` will automatically run `cargo fmt` and `cargo clippy` for you.
Several things to note:
- If `cargo fmt` or `cargo clippy` returns an error, the pre-commit hook will fail. You should fix all errors before trying to commit again.
- If `cargo fmt` reformats some files, the pre-commit hook will also fail. You should review the changes and, if satisfied, try to commit again.

8
flake.lock generated
View File

@ -2,16 +2,16 @@
"nodes": { "nodes": {
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1701389149, "lastModified": 1733940404,
"narHash": "sha256-rU1suTIEd5DGCaAXKW6yHoCfR1mnYjOXQFOaH7M23js=", "narHash": "sha256-Pj39hSoUA86ZePPF/UXiYHHM7hMIkios8TYG29kQT4g=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "5de0b32be6e85dc1a9404c75131316e4ffbc634c", "rev": "5d67ea6b4b63378b9c13be21e2ec9d1afc921713",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "NixOS", "owner": "NixOS",
"ref": "nixos-23.11", "ref": "nixos-unstable",
"repo": "nixpkgs", "repo": "nixpkgs",
"type": "github" "type": "github"
} }

View File

@ -1,11 +1,12 @@
{ {
description = "The third-generation ARTIQ compiler"; description = "The third-generation ARTIQ compiler";
inputs.nixpkgs.url = github:NixOS/nixpkgs/nixos-23.11; inputs.nixpkgs.url = github:NixOS/nixpkgs/nixos-unstable;
outputs = { self, nixpkgs }: outputs = { self, nixpkgs }:
let let
pkgs = import nixpkgs { system = "x86_64-linux"; }; pkgs = import nixpkgs { system = "x86_64-linux"; };
pkgs32 = import nixpkgs { system = "i686-linux"; };
in rec { in rec {
packages.x86_64-linux = rec { packages.x86_64-linux = rec {
llvm-nac3 = pkgs.callPackage ./nix/llvm {}; llvm-nac3 = pkgs.callPackage ./nix/llvm {};
@ -15,6 +16,22 @@
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
''; '';
demo-linalg-stub = pkgs.rustPlatform.buildRustPackage {
name = "demo-linalg-stub";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
demo-linalg-stub32 = pkgs32.rustPlatform.buildRustPackage {
name = "demo-linalg-stub32";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
nac3artiq = pkgs.python3Packages.toPythonModule ( nac3artiq = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage rec { pkgs.rustPlatform.buildRustPackage rec {
name = "nac3artiq"; name = "nac3artiq";
@ -24,15 +41,19 @@
lockFile = ./Cargo.lock; lockFile = ./Cargo.lock;
}; };
passthru.cargoLock = cargoLock; passthru.cargoLock = cargoLock;
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_14.clang llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ]; nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
buildInputs = [ pkgs.python3 llvm-nac3 ]; buildInputs = [ pkgs.python3 llvm-nac3 ];
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ]; checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
checkPhase = checkPhase =
'' ''
echo "Running Cargo tests..." echo "Checking nac3standalone demos..."
pushd nac3standalone/demo pushd nac3standalone/demo
patchShebangs . patchShebangs .
export DEMO_LINALG_STUB=${demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${demo-linalg-stub32}/lib/liblinalg.a
./check_demos.sh -i686
popd popd
echo "Running Cargo tests..."
cargoCheckHook cargoCheckHook
''; '';
installPhase = installPhase =
@ -86,18 +107,18 @@
(pkgs.fetchFromGitHub { (pkgs.fetchFromGitHub {
owner = "m-labs"; owner = "m-labs";
repo = "sipyco"; repo = "sipyco";
rev = "939f84f9b5eef7efbf7423c735d1834783b6140e"; rev = "094a6cd63ffa980ef63698920170e50dc9ba77fd";
sha256 = "sha256-15Nun4EY35j+6SPZkjzZtyH/ncxLS60KuGJjFh5kSTc="; sha256 = "sha256-PPnAyDedUQ7Og/Cby9x5OT9wMkNGTP8GS53V6N/dk4w=";
}) })
(pkgs.fetchFromGitHub { (pkgs.fetchFromGitHub {
owner = "m-labs"; owner = "m-labs";
repo = "artiq"; repo = "artiq";
rev = "8b4572f9cad34ac0c2b6f6bba9382e7b59b2f93b"; rev = "28c9de3e251daa89a8c9fd79d5ab64a3ec03bac6";
sha256 = "sha256-O/0sUSxxXU1AL9cmT9qdzCkzdOKREBNftz22/8ouQcc="; sha256 = "sha256-vAvpbHc5B+1wtG8zqN7j9dQE1ON+i22v+uqA+tw6Gak=";
}) })
]; ];
buildInputs = [ buildInputs = [
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb nac3artiq-instrumented ])) (python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb ps.platformdirs nac3artiq-instrumented ]))
pkgs.llvmPackages_14.llvm.out pkgs.llvmPackages_14.llvm.out
]; ];
phases = [ "buildPhase" "installPhase" ]; phases = [ "buildPhase" "installPhase" ];
@ -147,7 +168,7 @@
buildInputs = with pkgs; [ buildInputs = with pkgs; [
# build dependencies # build dependencies
packages.x86_64-linux.llvm-nac3 packages.x86_64-linux.llvm-nac3
llvmPackages_14.clang # demo (pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos
packages.x86_64-linux.llvm-tools-irrt packages.x86_64-linux.llvm-tools-irrt
cargo cargo
rustc rustc
@ -157,8 +178,14 @@
# development tools # development tools
cargo-insta cargo-insta
clippy clippy
pre-commit
rustfmt rustfmt
]; ];
shellHook =
''
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${packages.x86_64-linux.demo-linalg-stub32}/lib/liblinalg.a
'';
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -9,18 +9,13 @@ name = "nac3artiq"
crate-type = ["cdylib"] crate-type = ["cdylib"]
[dependencies] [dependencies]
itertools = "0.12" itertools = "0.13"
pyo3 = { version = "0.20", features = ["extension-module"] } pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
parking_lot = "0.12" parking_lot = "0.12"
tempfile = "3.8" tempfile = "3.13"
nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
nac3ld = { path = "../nac3ld" } nac3ld = { path = "../nac3ld" }
[dependencies.inkwell]
version = "0.2"
default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[features] [features]
init-llvm-profile = [] init-llvm-profile = []
no-escape-analysis = ["nac3core/no-escape-analysis"]

View File

@ -7,33 +7,6 @@ class EmbeddingMap:
self.function_map = {} self.function_map = {}
self.attributes_writeback = [] self.attributes_writeback = []
# preallocate exception names
self.preallocate_runtime_exception_names(["RuntimeError",
"RTIOUnderflow",
"RTIOOverflow",
"RTIODestinationUnreachable",
"DMAError",
"I2CError",
"CacheError",
"SPIError",
"0:ZeroDivisionError",
"0:IndexError",
"0:ValueError",
"0:RuntimeError",
"0:AssertionError",
"0:KeyError",
"0:NotImplementedError",
"0:OverflowError",
"0:IOError",
"0:UnwrapNoneError"])
def preallocate_runtime_exception_names(self, names):
for i, name in enumerate(names):
if ":" not in name:
name = "0:artiq.coredevice.exceptions." + name
exn_id = self.store_str(name)
assert exn_id == i
def store_function(self, key, fun): def store_function(self, key, fun):
self.function_map[key] = fun self.function_map[key] = fun
return key return key

View File

@ -0,0 +1,24 @@
from min_artiq import *
from numpy import int32
@nac3
class EmptyList:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@rpc
def get_empty(self) -> list[int32]:
return []
@kernel
def run(self):
a: list[int32] = self.get_empty()
if a != []:
raise ValueError
if __name__ == "__main__":
EmptyList().run()

View File

@ -112,10 +112,15 @@ def extern(function):
register_function(function) register_function(function)
return function return function
def rpc(function):
"""Decorates a function declaration defined by the core device runtime.""" def rpc(arg=None, flags={}):
register_function(function) """Decorates a function or method to be executed on the host interpreter."""
return function if arg is None:
def inner_decorator(function):
return rpc(function, flags)
return inner_decorator
register_function(arg)
return arg
def kernel(function_or_method): def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device.""" """Decorates a function or method to be executed on the core device."""
@ -201,7 +206,7 @@ class Core:
embedding = EmbeddingMap() embedding = EmbeddingMap()
if allow_registration: if allow_registration:
compiler.analyze(registered_functions, registered_classes) compiler.analyze(registered_functions, registered_classes, set())
allow_registration = False allow_registration = False
if hasattr(method, "__self__"): if hasattr(method, "__self__"):

26
nac3artiq/demo/str_abi.py Normal file
View File

@ -0,0 +1,26 @@
from min_artiq import *
from numpy import ndarray, zeros as np_zeros
@nac3
class StrFail:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@kernel
def hello(self, arg: str):
pass
@kernel
def consume_ndarray(self, arg: ndarray[str, 1]):
pass
def run(self):
self.hello("world")
self.consume_ndarray(np_zeros([10], dtype=str))
if __name__ == "__main__":
StrFail().run()

View File

@ -0,0 +1,24 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
core: KernelInvariant[Core]
attr1: KernelInvariant[str]
attr2: KernelInvariant[int32]
def __init__(self):
self.core = Core()
self.attr2 = 32
self.attr1 = "SAMPLE"
@kernel
def run(self):
print_int32(self.attr2)
self.attr1
if __name__ == "__main__":
Demo().run()

View File

@ -0,0 +1,40 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
attr1: KernelInvariant[int32] = 2
attr2: int32 = 4
attr3: Kernel[int32]
@kernel
def __init__(self):
self.attr3 = 8
@nac3
class NAC3Devices:
core: KernelInvariant[Core]
attr4: KernelInvariant[int32] = 16
def __init__(self):
self.core = Core()
@kernel
def run(self):
Demo.attr1 # Supported
# Demo.attr2 # Field not accessible on Kernel
# Demo.attr3 # Only attributes can be accessed in this way
# Demo.attr1 = 2 # Attributes are immutable
self.attr4 # Attributes can be accessed within class
obj = Demo()
obj.attr1 # Attributes can be accessed by class objects
NAC3Devices.attr4 # Attributes accessible for classes without __init__
if __name__ == "__main__":
NAC3Devices().run()

File diff suppressed because it is too large Load Diff

View File

@ -1,60 +1,74 @@
use std::collections::{HashMap, HashSet}; #![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
use std::fs; #![warn(clippy::pedantic)]
use std::io::Write; #![allow(
use std::process::Command; unsafe_op_in_unsafe_fn,
use std::rc::Rc; clippy::cast_possible_truncation,
use std::sync::Arc; clippy::cast_sign_loss,
clippy::enum_glob_use,
clippy::similar_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
use inkwell::{ use std::{
memory_buffer::MemoryBuffer, collections::{HashMap, HashSet},
module::{Linkage, Module}, fs,
passes::PassBuilderOptions, io::Write,
support::is_multithreaded, process::Command,
targets::*, rc::Rc,
OptimizationLevel, sync::Arc,
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3core::codegen::{CodeGenLLVMOptions, CodeGenTargetMachineOptions, gen_func_impl};
use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{TypeEnum, Unifier};
use nac3parser::{
ast::{ExprKind, Stmt, StmtKind, StrRef},
parser::parse_program,
};
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use pyo3::create_exception;
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use pyo3::{
use nac3core::{ create_exception, exceptions,
codegen::irrt::load_irrt, prelude::*,
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry}, types::{PyBytes, PyDict, PyNone, PySet},
symbol_resolver::SymbolResolver,
toplevel::{
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, GenCall, TopLevelDef,
},
typecheck::typedef::{FunSignature, FuncArg},
typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
}; };
use nac3ld::Linker;
use tempfile::{self, TempDir}; use tempfile::{self, TempDir};
use crate::codegen::attributes_writeback; use nac3core::{
use crate::{ codegen::{
codegen::{rpc_codegen_callback, ArtiqCodeGenerator}, concrete_type::ConcreteTypeStore, gen_func_impl, irrt::load_irrt, CodeGenLLVMOptions,
symbol_resolver::{InnerResolver, PythonHelper, Resolver, DeferredEvaluationStore}, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator, WithCall, WorkerRegistry,
},
inkwell::{
context::Context,
memory_buffer::MemoryBuffer,
module::{FlagBehavior, Linkage, Module},
passes::PassBuilderOptions,
support::is_multithreaded,
targets::*,
OptimizationLevel,
},
nac3parser::{
ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
parser::parse_program,
},
symbol_resolver::SymbolResolver,
toplevel::{
builtins::get_exn_constructor,
composer::{BuiltinFuncCreator, BuiltinFuncSpec, ComposerConfig, TopLevelComposer},
DefinitionId, GenCall, TopLevelDef,
},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
}; };
use nac3ld::Linker;
use codegen::{
attributes_writeback, gen_core_log, gen_rtio_log, rpc_codegen_callback, ArtiqCodeGenerator,
};
use symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver};
use timeline::TimeFns;
mod codegen; mod codegen;
mod symbol_resolver; mod symbol_resolver;
mod timeline; mod timeline;
use timeline::TimeFns;
#[derive(PartialEq, Clone, Copy)] #[derive(PartialEq, Clone, Copy)]
enum Isa { enum Isa {
Host, Host,
@ -63,6 +77,17 @@ enum Isa {
CortexA9, CortexA9,
} }
impl Isa {
/// Returns the number of bits in `size_t` for the [`Isa`].
fn get_size_type(self) -> u32 {
if self == Isa::Host {
64u32
} else {
32u32
}
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct PrimitivePythonId { pub struct PrimitivePythonId {
int: u64, int: u64,
@ -73,7 +98,11 @@ pub struct PrimitivePythonId {
float: u64, float: u64,
float64: u64, float64: u64,
bool: u64, bool: u64,
np_bool_: u64,
string: u64,
np_str_: u64,
list: u64, list: u64,
ndarray: u64,
tuple: u64, tuple: u64,
typevar: u64, typevar: u64,
const_generic_marker: u64, const_generic_marker: u64,
@ -93,7 +122,7 @@ struct Nac3 {
isa: Isa, isa: Isa,
time_fns: &'static (dyn TimeFns + Sync), time_fns: &'static (dyn TimeFns + Sync),
primitive: PrimitiveStore, primitive: PrimitiveStore,
builtins: Vec<(StrRef, FunSignature, Arc<GenCall>)>, builtins: Vec<BuiltinFuncSpec>,
pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>, pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
primitive_ids: PrimitivePythonId, primitive_ids: PrimitivePythonId,
working_directory: TempDir, working_directory: TempDir,
@ -113,22 +142,38 @@ impl Nac3 {
module: &PyObject, module: &PyObject,
registered_class_ids: &HashSet<u64>, registered_class_ids: &HashSet<u64>,
) -> PyResult<()> { ) -> PyResult<()> {
let (module_name, source_file) = Python::with_gil(|py| -> PyResult<(String, String)> { let (module_name, source_file, source) =
let module: &PyAny = module.extract(py)?; Python::with_gil(|py| -> PyResult<(String, String, String)> {
Ok((module.getattr("__name__")?.extract()?, module.getattr("__file__")?.extract()?)) let module: &PyAny = module.extract(py)?;
})?; let source_file = module.getattr("__file__");
let (source_file, source) = if let Ok(source_file) = source_file {
let source_file = source_file.extract()?;
(
source_file,
fs::read_to_string(source_file).map_err(|e| {
exceptions::PyIOError::new_err(format!(
"failed to read input file: {e}"
))
})?,
)
} else {
// kernels submitted by content have no file
// but still can provide source by StringLoader
let get_src_fn = module
.getattr("__loader__")?
.extract::<PyObject>()?
.getattr(py, "get_source")?;
("<expcontent>", get_src_fn.call1(py, (PyNone::get(py),))?.extract(py)?)
};
Ok((module.getattr("__name__")?.extract()?, source_file.to_string(), source))
})?;
let source = fs::read_to_string(&source_file).map_err(|e| {
exceptions::PyIOError::new_err(format!("failed to read input file: {e}"))
})?;
let parser_result = parse_program(&source, source_file.into()) let parser_result = parse_program(&source, source_file.into())
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?; .map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?;
for mut stmt in parser_result { for mut stmt in parser_result {
let include = match stmt.node { let include = match stmt.node {
StmtKind::ClassDef { StmtKind::ClassDef { ref decorator_list, ref mut body, ref mut bases, .. } => {
ref decorator_list, ref mut body, ref mut bases, ..
} => {
let nac3_class = decorator_list.iter().any(|decorator| { let nac3_class = decorator_list.iter().any(|decorator| {
if let ExprKind::Name { id, .. } = decorator.node { if let ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "nac3" id.to_string() == "nac3"
@ -148,7 +193,8 @@ impl Nac3 {
if *id == "Exception".into() { if *id == "Exception".into() {
Ok(true) Ok(true)
} else { } else {
let base_obj = module.getattr(py, id.to_string().as_str())?; let base_obj =
module.getattr(py, id.to_string().as_str())?;
let base_id = id_fn.call1((base_obj,))?.extract()?; let base_id = id_fn.call1((base_obj,))?.extract()?;
Ok(registered_class_ids.contains(&base_id)) Ok(registered_class_ids.contains(&base_id))
} }
@ -161,10 +207,8 @@ impl Nac3 {
body.retain(|stmt| { body.retain(|stmt| {
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node { if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
decorator_list.iter().any(|decorator| { decorator_list.iter().any(|decorator| {
if let ExprKind::Name { id, .. } = decorator.node { if let Some(id) = decorator_id_string(decorator) {
id.to_string() == "kernel" id == "kernel" || id == "portable" || id == "rpc"
|| id.to_string() == "portable"
|| id.to_string() == "rpc"
} else { } else {
false false
} }
@ -177,9 +221,8 @@ impl Nac3 {
} }
StmtKind::FunctionDef { ref decorator_list, .. } => { StmtKind::FunctionDef { ref decorator_list, .. } => {
decorator_list.iter().any(|decorator| { decorator_list.iter().any(|decorator| {
if let ExprKind::Name { id, .. } = decorator.node { if let Some(id) = decorator_id_string(decorator) {
let id = id.to_string(); id == "extern" || id == "kernel" || id == "portable" || id == "rpc"
id == "extern" || id == "portable" || id == "kernel" || id == "rpc"
} else { } else {
false false
} }
@ -232,7 +275,7 @@ impl Nac3 {
arg_names.len(), arg_names.len(),
)); ));
} }
for (i, FuncArg { ty, default_value, name }) in args.iter().enumerate() { for (i, FuncArg { ty, default_value, name, .. }) in args.iter().enumerate() {
let in_name = match arg_names.get(i) { let in_name = match arg_names.get(i) {
Some(n) => n, Some(n) => n,
None if default_value.is_none() => { None if default_value.is_none() => {
@ -268,6 +311,64 @@ impl Nac3 {
None None
} }
/// Returns a [`Vec`] of builtins that needs to be initialized during method compilation time.
fn get_lateinit_builtins() -> Vec<Box<BuiltinFuncCreator>> {
vec![
Box::new(|primitives, unifier| {
let arg_ty = unifier.get_fresh_var(Some("T".into()), None);
(
"core_log".into(),
FunSignature {
args: vec![FuncArg {
name: "arg".into(),
ty: arg_ty.ty,
default_value: None,
is_vararg: false,
}],
ret: primitives.none,
vars: into_var_map([arg_ty]),
},
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
gen_core_log(ctx, &obj, fun, &args, generator)?;
Ok(None)
}))),
)
}),
Box::new(|primitives, unifier| {
let arg_ty = unifier.get_fresh_var(Some("T".into()), None);
(
"rtio_log".into(),
FunSignature {
args: vec![
FuncArg {
name: "channel".into(),
ty: primitives.str,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "arg".into(),
ty: arg_ty.ty,
default_value: None,
is_vararg: false,
},
],
ret: primitives.none,
vars: into_var_map([arg_ty]),
},
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
gen_rtio_log(ctx, &obj, fun, &args, generator)?;
Ok(None)
}))),
)
}),
]
}
fn compile_method<T>( fn compile_method<T>(
&self, &self,
obj: &PyAny, obj: &PyAny,
@ -277,9 +378,12 @@ impl Nac3 {
py: Python, py: Python,
link_fn: &dyn Fn(&Module) -> PyResult<T>, link_fn: &dyn Fn(&Module) -> PyResult<T>,
) -> PyResult<T> { ) -> PyResult<T> {
let size_t = self.isa.get_size_type();
let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new( let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
self.builtins.clone(), self.builtins.clone(),
Self::get_lateinit_builtins(),
ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" }, ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" },
size_t,
); );
let builtins = PyModule::import(py, "builtins")?; let builtins = PyModule::import(py, "builtins")?;
@ -327,8 +431,9 @@ impl Nac3 {
let class_obj; let class_obj;
if let StmtKind::ClassDef { name, .. } = &stmt.node { if let StmtKind::ClassDef { name, .. } = &stmt.node {
let class = py_module.getattr(name.to_string().as_str()).unwrap(); let class = py_module.getattr(name.to_string().as_str()).unwrap();
if issubclass.call1((class, exn_class)).unwrap().extract().unwrap() && if issubclass.call1((class, exn_class)).unwrap().extract().unwrap()
class.getattr("artiq_builtin").is_err() { && class.getattr("artiq_builtin").is_err()
{
class_obj = Some(class); class_obj = Some(class);
} else { } else {
class_obj = None; class_obj = None;
@ -353,7 +458,6 @@ impl Nac3 {
pyid_to_type: pyid_to_type.clone(), pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(), primitive_ids: self.primitive_ids.clone(),
global_value_ids: global_value_ids.clone(), global_value_ids: global_value_ids.clone(),
class_names: Mutex::default(),
name_to_pyid: name_to_pyid.clone(), name_to_pyid: name_to_pyid.clone(),
module: module.clone(), module: module.clone(),
id_to_pyval: RwLock::default(), id_to_pyval: RwLock::default(),
@ -374,19 +478,35 @@ impl Nac3 {
let (name, def_id, ty) = composer let (name, def_id, ty) = composer
.register_top_level(stmt.clone(), Some(resolver.clone()), path, false) .register_top_level(stmt.clone(), Some(resolver.clone()), path, false)
.map_err(|e| { .map_err(|e| {
CompileError::new_err(format!( CompileError::new_err(format!("compilation failed\n----------\n{e}"))
"compilation failed\n----------\n{e}"
))
})?; })?;
if let Some(class_obj) = class_obj { if let Some(class_obj) = class_obj {
self.exception_ids.write().insert(def_id.0, store_obj.call1(py, (class_obj, ))?.extract(py)?); self.exception_ids
.write()
.insert(def_id.0, store_obj.call1(py, (class_obj,))?.extract(py)?);
} }
match &stmt.node { match &stmt.node {
StmtKind::FunctionDef { decorator_list, .. } => { StmtKind::FunctionDef { decorator_list, .. } => {
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) { if decorator_list
store_fun.call1(py, (def_id.0.into_py(py), module.getattr(py, name.to_string().as_str()).unwrap())).unwrap(); .iter()
rpc_ids.push((None, def_id)); .any(|decorator| decorator_id_string(decorator) == Some("rpc".to_string()))
{
store_fun
.call1(
py,
(
def_id.0.into_py(py),
module.getattr(py, name.to_string().as_str()).unwrap(),
),
)
.unwrap();
let is_async = decorator_list.iter().any(|decorator| {
decorator_get_flags(decorator)
.iter()
.any(|constant| *constant == Constant::Str("async".into()))
});
rpc_ids.push((None, def_id, is_async));
} }
} }
StmtKind::ClassDef { name, body, .. } => { StmtKind::ClassDef { name, body, .. } => {
@ -394,19 +514,26 @@ impl Nac3 {
let class_obj = module.getattr(py, class_name.as_str()).unwrap(); let class_obj = module.getattr(py, class_name.as_str()).unwrap();
for stmt in body { for stmt in body {
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node { if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
if decorator_list.iter().any(|decorator| matches!(decorator.node, ExprKind::Name { id, .. } if id == "rpc".into())) { if decorator_list.iter().any(|decorator| {
decorator_id_string(decorator) == Some("rpc".to_string())
}) {
let is_async = decorator_list.iter().any(|decorator| {
decorator_get_flags(decorator)
.iter()
.any(|constant| *constant == Constant::Str("async".into()))
});
if name == &"__init__".into() { if name == &"__init__".into() {
return Err(CompileError::new_err(format!( return Err(CompileError::new_err(format!(
"compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})", "compilation failed\n----------\nThe constructor of class {} should not be decorated with rpc decorator (at {})",
class_name, stmt.location class_name, stmt.location
))); )));
} }
rpc_ids.push((Some((class_obj.clone(), *name)), def_id)); rpc_ids.push((Some((class_obj.clone(), *name)), def_id, is_async));
} }
} }
} }
} }
_ => () _ => (),
} }
let id = *name_to_pyid.get(&name).unwrap(); let id = *name_to_pyid.get(&name).unwrap();
@ -445,28 +572,36 @@ impl Nac3 {
pyid_to_type: pyid_to_type.clone(), pyid_to_type: pyid_to_type.clone(),
primitive_ids: self.primitive_ids.clone(), primitive_ids: self.primitive_ids.clone(),
global_value_ids: global_value_ids.clone(), global_value_ids: global_value_ids.clone(),
class_names: Mutex::default(),
id_to_pyval: RwLock::default(), id_to_pyval: RwLock::default(),
id_to_primitive: RwLock::default(), id_to_primitive: RwLock::default(),
field_to_val: RwLock::default(), field_to_val: RwLock::default(),
name_to_pyid, name_to_pyid,
module: module.to_object(py), module: module.to_object(py),
helper, helper: helper.clone(),
string_store: self.string_store.clone(), string_store: self.string_store.clone(),
exception_ids: self.exception_ids.clone(), exception_ids: self.exception_ids.clone(),
deferred_eval_store: self.deferred_eval_store.clone(), deferred_eval_store: self.deferred_eval_store.clone(),
}); });
let resolver = Arc::new(Resolver(inner_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>; let resolver =
Arc::new(Resolver(inner_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
let (_, def_id, _) = composer let (_, def_id, _) = composer
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false) .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
.unwrap(); .unwrap();
// Process IRRT
let context = Context::create();
let irrt = load_irrt(&context, resolver.as_ref());
let fun_signature = let fun_signature =
FunSignature { args: vec![], ret: self.primitive.none, vars: HashMap::new() }; FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new(); let mut cache = HashMap::new();
let signature = let signature = store.from_signature(
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache); &mut composer.unifier,
&self.primitive,
&fun_signature,
&mut cache,
);
let signature = store.add_cty(signature); let signature = store.add_cty(signature);
if let Err(e) = composer.start_analysis(true) { if let Err(e) = composer.start_analysis(true) {
@ -485,24 +620,21 @@ impl Nac3 {
msg.unwrap_or(e.iter().sorted().join("\n----------\n")) msg.unwrap_or(e.iter().sorted().join("\n----------\n"))
))) )))
} else { } else {
Err(CompileError::new_err( Err(CompileError::new_err(format!(
format!( "compilation failed\n----------\n{}",
"compilation failed\n----------\n{}", e.iter().sorted().join("\n----------\n"),
e.iter().sorted().join("\n----------\n"), )))
), };
))
}
} }
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
{ {
let rpc_codegen = rpc_codegen_callback();
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
for (class_data, id) in &rpc_ids { for (class_data, id, is_async) in &rpc_ids {
let mut def = defs[id.0].write(); let mut def = defs[id.0].write();
match &mut *def { match &mut *def {
TopLevelDef::Function { codegen_callback, .. } => { TopLevelDef::Function { codegen_callback, .. } => {
*codegen_callback = Some(rpc_codegen.clone()); *codegen_callback = Some(rpc_codegen_callback(*is_async));
} }
TopLevelDef::Class { methods, .. } => { TopLevelDef::Class { methods, .. } => {
let (class_def, method_name) = class_data.as_ref().unwrap(); let (class_def, method_name) = class_data.as_ref().unwrap();
@ -513,19 +645,26 @@ impl Nac3 {
if let TopLevelDef::Function { codegen_callback, .. } = if let TopLevelDef::Function { codegen_callback, .. } =
&mut *defs[id.0].write() &mut *defs[id.0].write()
{ {
*codegen_callback = Some(rpc_codegen.clone()); *codegen_callback = Some(rpc_codegen_callback(*is_async));
store_fun store_fun
.call1( .call1(
py, py,
( (
id.0.into_py(py), id.0.into_py(py),
class_def.getattr(py, name.to_string().as_str()).unwrap(), class_def
.getattr(py, name.to_string().as_str())
.unwrap(),
), ),
) )
.unwrap(); .unwrap();
} }
} }
} }
TopLevelDef::Variable { .. } => {
return Err(CompileError::new_err(String::from(
"Unsupported @rpc annotation on global variable",
)))
}
} }
} }
} }
@ -534,7 +673,8 @@ impl Nac3 {
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
let mut definition = defs[def_id.0].write(); let mut definition = defs[def_id.0].write();
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } =
&mut *definition else { &mut *definition
else {
unreachable!() unreachable!()
}; };
@ -545,29 +685,12 @@ impl Nac3 {
let task = CodeGenTask { let task = CodeGenTask {
subst: Vec::default(), subst: Vec::default(),
symbol_name: "__modinit__".to_string(), symbol_name: "__modinit__".to_string(),
body: instance.body,
signature,
resolver: resolver.clone(),
store,
unifier_index: instance.unifier_id,
calls: instance.calls,
id: 0,
};
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature =
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache);
let signature = store.add_cty(signature);
let attributes_writeback_task = CodeGenTask {
subst: Vec::default(),
symbol_name: "attributes_writeback".to_string(),
body: Arc::new(Vec::default()), body: Arc::new(Vec::default()),
signature, signature,
resolver, resolver,
store, store,
unifier_index: instance.unifier_id, unifier_index: instance.unifier_id,
calls: Arc::new(HashMap::default()), calls: instance.calls,
id: 0, id: 0,
}; };
@ -580,7 +703,9 @@ impl Nac3 {
let buffer = buffer.as_slice().into(); let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}))); })));
let size_t = if self.isa == Isa::Host { 64 } else { 32 }; let size_t = context
.ptr_sized_int_type(&self.get_llvm_target_machine().get_target_data(), None)
.get_bit_width();
let num_threads = if is_multithreaded() { 4 } else { 1 }; let num_threads = if is_multithreaded() { 4 } else { 1 };
let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect(); let thread_names: Vec<String> = (0..num_threads).map(|_| "main".to_string()).collect();
let threads: Vec<_> = thread_names let threads: Vec<_> = thread_names
@ -589,49 +714,81 @@ impl Nac3 {
.collect(); .collect();
let membuffer = membuffers.clone(); let membuffer = membuffers.clone();
let mut has_return = false;
py.allow_threads(|| { py.allow_threads(|| {
let (registry, handles) = WorkerRegistry::create_workers( let (registry, handles) =
threads, WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
top_level.clone(),
&self.llvm_options,
&f
);
registry.add_task(task);
registry.wait_tasks_complete(handles);
let mut generator = ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns); let mut generator = ArtiqCodeGenerator::new("main".to_string(), size_t, self.time_fns);
let context = inkwell::context::Context::create(); let context = Context::create();
let module = context.create_module("attributes_writeback"); let module = context.create_module("main");
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
module.add_basic_value_flag(
"Debug Info Version",
FlagBehavior::Warning,
context.i32_type().const_int(3, false),
);
module.add_basic_value_flag(
"Dwarf Version",
FlagBehavior::Warning,
context.i32_type().const_int(4, false),
);
let builder = context.create_builder(); let builder = context.create_builder();
let (_, module, _) = gen_func_impl(&context, &mut generator, &registry, builder, module, let (_, module, _) = gen_func_impl(
attributes_writeback_task, |generator, ctx| { &context,
attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes) &mut generator,
}).unwrap(); &registry,
builder,
module,
task,
|generator, ctx| {
assert_eq!(instance.body.len(), 1, "toplevel module should have 1 statement");
let StmtKind::Expr { value: ref expr, .. } = instance.body[0].node else {
unreachable!("toplevel statement must be an expression")
};
let ExprKind::Call { .. } = expr.node else {
unreachable!("toplevel expression must be a function call")
};
let return_obj =
generator.gen_expr(ctx, expr)?.map(|value| (expr.custom.unwrap(), value));
has_return = return_obj.is_some();
registry.wait_tasks_complete(handles);
attributes_writeback(
ctx,
generator,
inner_resolver.as_ref(),
&host_attributes,
return_obj,
)
},
)
.unwrap();
let buffer = module.write_bitcode_to_memory(); let buffer = module.write_bitcode_to_memory();
let buffer = buffer.as_slice().into(); let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}); });
let context = inkwell::context::Context::create(); embedding_map.setattr("expects_return", has_return).unwrap();
// Link all modules into `main`.
let buffers = membuffers.lock(); let buffers = membuffers.lock();
let main = context let main = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main")) .create_module_from_ir(MemoryBuffer::create_from_memory_range(
buffers.last().unwrap(),
"main",
))
.unwrap(); .unwrap();
for buffer in buffers.iter().skip(1) { for buffer in buffers.iter().rev().skip(1) {
let other = context let other = context
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main")) .create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
.unwrap(); .unwrap();
main.link_in_module(other) main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
.map_err(|err| CompileError::new_err(err.to_string()))?;
} }
let builder = context.create_builder(); main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?;
let modinit_return = main.get_function("__modinit__").unwrap().get_last_basic_block().unwrap().get_terminator().unwrap();
builder.position_before(&modinit_return);
builder.build_call(main.get_function("attributes_writeback").unwrap(), &[], "attributes_writeback");
main.link_in_module(load_irrt(&context))
.map_err(|err| CompileError::new_err(err.to_string()))?;
let mut function_iter = main.get_first_function(); let mut function_iter = main.get_first_function();
while let Some(func) = function_iter { while let Some(func) = function_iter {
@ -642,10 +799,7 @@ impl Nac3 {
} }
// Demote all global variables that will not be referenced in the kernel to private // Demote all global variables that will not be referenced in the kernel to private
let preserved_symbols: Vec<&'static [u8]> = vec![ let preserved_symbols: Vec<&'static [u8]> = vec![b"typeinfo", b"now"];
b"typeinfo",
b"now",
];
let mut global_option = main.get_first_global(); let mut global_option = main.get_first_global();
while let Some(global) = global_option { while let Some(global) = global_option {
if !preserved_symbols.contains(&(global.get_name().to_bytes())) { if !preserved_symbols.contains(&(global.get_name().to_bytes())) {
@ -654,7 +808,9 @@ impl Nac3 {
global_option = global.get_next_global(); global_option = global.get_next_global();
} }
let target_machine = self.llvm_options.target let target_machine = self
.llvm_options
.target
.create_target_machine(self.llvm_options.opt_level) .create_target_machine(self.llvm_options.opt_level)
.expect("couldn't create target machine"); .expect("couldn't create target machine");
@ -666,6 +822,20 @@ impl Nac3 {
panic!("Failed to run optimization for module `main`: {}", err.to_string()); panic!("Failed to run optimization for module `main`: {}", err.to_string());
} }
Python::with_gil(|py| {
let string_store = self.string_store.read();
let mut string_store_vec = string_store.iter().collect::<Vec<_>>();
string_store_vec.sort_by(|(_s1, key1), (_s2, key2)| key1.cmp(key2));
for (s, key) in string_store_vec {
let embed_key: i32 = helper.store_str.call1(py, (s,)).unwrap().extract(py).unwrap();
assert_eq!(
embed_key, *key,
"string {s} is out of sync between embedding map (key={embed_key}) and \
the internal string store (key={key})"
);
}
});
link_fn(&main) link_fn(&main)
} }
@ -718,10 +888,42 @@ impl Nac3 {
} }
} }
fn link_with_lld( /// Retrieves the Name.id from a decorator, supports decorators with arguments.
elf_filename: String, fn decorator_id_string(decorator: &Located<ExprKind>) -> Option<String> {
obj_filename: String, if let ExprKind::Name { id, .. } = decorator.node {
) -> PyResult<()>{ // Bare decorator
return Some(id.to_string());
} else if let ExprKind::Call { func, .. } = &decorator.node {
// Decorators that are calls (e.g. "@rpc()") have Call for the node,
// need to extract the id from within.
if let ExprKind::Name { id, .. } = func.node {
return Some(id.to_string());
}
}
None
}
/// Retrieves flags from a decorator, if any.
fn decorator_get_flags(decorator: &Located<ExprKind>) -> Vec<Constant> {
let mut flags = vec![];
if let ExprKind::Call { keywords, .. } = &decorator.node {
for keyword in keywords {
if keyword.node.arg != Some("flags".into()) {
continue;
}
if let ExprKind::Set { elts } = &keyword.node.value.node {
for elt in elts {
if let ExprKind::Constant { value, .. } = &elt.node {
flags.push(value.clone());
}
}
}
}
}
flags
}
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
let linker_args = vec![ let linker_args = vec![
"-shared".to_string(), "-shared".to_string(),
"--eh-frame-hdr".to_string(), "--eh-frame-hdr".to_string(),
@ -740,9 +942,7 @@ fn link_with_lld(
return Err(CompileError::new_err("failed to start linker")); return Err(CompileError::new_err("failed to start linker"));
} }
} else { } else {
return Err(CompileError::new_err( return Err(CompileError::new_err("linker returned non-zero status code"));
"linker returned non-zero status code",
));
} }
Ok(()) Ok(())
@ -752,7 +952,7 @@ fn add_exceptions(
composer: &mut TopLevelComposer, composer: &mut TopLevelComposer,
builtin_def: &mut HashMap<StrRef, DefinitionId>, builtin_def: &mut HashMap<StrRef, DefinitionId>,
builtin_ty: &mut HashMap<StrRef, Type>, builtin_ty: &mut HashMap<StrRef, Type>,
error_names: &[&str] error_names: &[&str],
) -> Vec<Type> { ) -> Vec<Type> {
let mut types = Vec::new(); let mut types = Vec::new();
// note: this is only for builtin exceptions, i.e. the exception name is "0:{exn}" // note: this is only for builtin exceptions, i.e. the exception name is "0:{exn}"
@ -765,7 +965,7 @@ fn add_exceptions(
// constructor id // constructor id
def_id + 1, def_id + 1,
&mut composer.unifier, &mut composer.unifier,
&composer.primitives_ty &composer.primitives_ty,
); );
composer.definition_ast_list.push((Arc::new(RwLock::new(exception_class)), None)); composer.definition_ast_list.push((Arc::new(RwLock::new(exception_class)), None));
composer.definition_ast_list.push((Arc::new(RwLock::new(exception_fn)), None)); composer.definition_ast_list.push((Arc::new(RwLock::new(exception_fn)), None));
@ -792,11 +992,11 @@ impl Nac3 {
Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS, Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS, Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS,
}; };
let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0; let (primitive, _) = TopLevelComposer::make_primitives(isa.get_size_type());
let builtins = vec![ let builtins = vec![
( (
"now_mu".into(), "now_mu".into(),
FunSignature { args: vec![], ret: primitive.int64, vars: HashMap::new() }, FunSignature { args: vec![], ret: primitive.int64, vars: VarMap::new() },
Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| { Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| {
Ok(Some(time_fns.emit_now_mu(ctx))) Ok(Some(time_fns.emit_now_mu(ctx)))
}))), }))),
@ -808,13 +1008,15 @@ impl Nac3 {
name: "t".into(), name: "t".into(),
ty: primitive.int64, ty: primitive.int64,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: primitive.none, ret: primitive.none,
vars: HashMap::new(), vars: VarMap::new(),
}, },
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| { Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap(); let arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
time_fns.emit_at_mu(ctx, arg); time_fns.emit_at_mu(ctx, arg);
Ok(None) Ok(None)
}))), }))),
@ -826,13 +1028,15 @@ impl Nac3 {
name: "dt".into(), name: "dt".into(),
ty: primitive.int64, ty: primitive.int64,
default_value: None, default_value: None,
is_vararg: false,
}], }],
ret: primitive.none, ret: primitive.none,
vars: HashMap::new(), vars: VarMap::new(),
}, },
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| { Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap(); let arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
time_fns.emit_delay_mu(ctx, arg); time_fns.emit_delay_mu(ctx, arg);
Ok(None) Ok(None)
}))), }))),
@ -846,8 +1050,9 @@ impl Nac3 {
let types_mod = PyModule::import(py, "types").unwrap(); let types_mod = PyModule::import(py, "types").unwrap();
let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap(); let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap();
let get_attr_id = |obj: &PyModule, attr| id_fn.call1((obj.getattr(attr).unwrap(),)) let get_attr_id = |obj: &PyModule, attr| {
.unwrap().extract().unwrap(); id_fn.call1((obj.getattr(attr).unwrap(),)).unwrap().extract().unwrap()
};
let primitive_ids = PrimitivePythonId { let primitive_ids = PrimitivePythonId {
virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()), virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()),
generic_alias: ( generic_alias: (
@ -856,16 +1061,22 @@ impl Nac3 {
), ),
none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()), none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()),
typevar: get_attr_id(typing_mod, "TypeVar"), typevar: get_attr_id(typing_mod, "TypeVar"),
const_generic_marker: get_id(artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap()), const_generic_marker: get_id(
artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap(),
),
int: get_attr_id(builtins_mod, "int"), int: get_attr_id(builtins_mod, "int"),
int32: get_attr_id(numpy_mod, "int32"), int32: get_attr_id(numpy_mod, "int32"),
int64: get_attr_id(numpy_mod, "int64"), int64: get_attr_id(numpy_mod, "int64"),
uint32: get_attr_id(numpy_mod, "uint32"), uint32: get_attr_id(numpy_mod, "uint32"),
uint64: get_attr_id(numpy_mod, "uint64"), uint64: get_attr_id(numpy_mod, "uint64"),
bool: get_attr_id(builtins_mod, "bool"), bool: get_attr_id(builtins_mod, "bool"),
np_bool_: get_attr_id(numpy_mod, "bool_"),
string: get_attr_id(builtins_mod, "str"),
np_str_: get_attr_id(numpy_mod, "str_"),
float: get_attr_id(builtins_mod, "float"), float: get_attr_id(builtins_mod, "float"),
float64: get_attr_id(numpy_mod, "float64"), float64: get_attr_id(numpy_mod, "float64"),
list: get_attr_id(builtins_mod, "list"), list: get_attr_id(builtins_mod, "list"),
ndarray: get_attr_id(numpy_mod, "ndarray"),
tuple: get_attr_id(builtins_mod, "tuple"), tuple: get_attr_id(builtins_mod, "tuple"),
exception: get_attr_id(builtins_mod, "Exception"), exception: get_attr_id(builtins_mod, "Exception"),
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()), option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
@ -874,6 +1085,48 @@ impl Nac3 {
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
fs::write(working_directory.path().join("kernel.ld"), include_bytes!("kernel.ld")).unwrap(); fs::write(working_directory.path().join("kernel.ld"), include_bytes!("kernel.ld")).unwrap();
let mut string_store: HashMap<String, i32> = HashMap::default();
// Keep this list of exceptions in sync with `EXCEPTION_ID_LOOKUP` in `artiq::firmware::ksupport::eh_artiq`
// The exceptions declared here must be defined in `artiq.coredevice.exceptions`
// Verify synchronization by running the test cases in `artiq.test.coredevice.test_exceptions`
let runtime_exception_names = [
"RTIOUnderflow",
"RTIOOverflow",
"RTIODestinationUnreachable",
"DMAError",
"I2CError",
"CacheError",
"SPIError",
"SubkernelError",
"0:AssertionError",
"0:AttributeError",
"0:IndexError",
"0:IOError",
"0:KeyError",
"0:NotImplementedError",
"0:OverflowError",
"0:RuntimeError",
"0:TimeoutError",
"0:TypeError",
"0:ValueError",
"0:ZeroDivisionError",
"0:LinAlgError",
"UnwrapNoneError",
];
// Preallocate runtime exception names
for (i, name) in runtime_exception_names.iter().enumerate() {
let exn_name = if name.find(':').is_none() {
format!("0:artiq.coredevice.exceptions.{name}")
} else {
(*name).to_string()
};
let id = i32::try_from(i).unwrap();
string_store.insert(exn_name, id);
}
Ok(Nac3 { Ok(Nac3 {
isa, isa,
time_fns, time_fns,
@ -883,17 +1136,22 @@ impl Nac3 {
top_levels: Vec::default(), top_levels: Vec::default(),
pyid_to_def: Arc::default(), pyid_to_def: Arc::default(),
working_directory, working_directory,
string_store: Arc::default(), string_store: Arc::new(string_store.into()),
exception_ids: Arc::default(), exception_ids: Arc::default(),
deferred_eval_store: DeferredEvaluationStore::new(), deferred_eval_store: DeferredEvaluationStore::new(),
llvm_options: CodeGenLLVMOptions { llvm_options: CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default, opt_level: OptimizationLevel::Default,
target: Nac3::get_llvm_target_options(isa), target: Nac3::get_llvm_target_options(isa),
} },
}) })
} }
fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> { fn analyze(
&mut self,
functions: &PySet,
classes: &PySet,
content_modules: &PySet,
) -> PyResult<()> {
let (modules, class_ids) = let (modules, class_ids) =
Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> { Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
let mut modules: HashMap<u64, PyObject> = HashMap::new(); let mut modules: HashMap<u64, PyObject> = HashMap::new();
@ -903,14 +1161,22 @@ impl Nac3 {
let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?; let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
for function in functions { for function in functions {
let module = getmodule_fn.call1((function,))?.extract()?; let module: PyObject = getmodule_fn.call1((function,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module); if !module.is_none(py) {
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
} }
for class in classes { for class in classes {
let module = getmodule_fn.call1((class,))?.extract()?; let module: PyObject = getmodule_fn.call1((class,))?.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module); if !module.is_none(py) {
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
class_ids.insert(id_fn.call1((class,))?.extract()?); class_ids.insert(id_fn.call1((class,))?.extract()?);
} }
for module in content_modules {
let module: PyObject = module.extract()?;
modules.insert(id_fn.call1((&module,))?.extract()?, module);
}
Ok((modules, class_ids)) Ok((modules, class_ids))
})?; })?;
@ -930,7 +1196,7 @@ impl Nac3 {
py: Python, py: Python,
) -> PyResult<()> { ) -> PyResult<()> {
let target_machine = self.get_llvm_target_machine(); let target_machine = self.get_llvm_target_machine();
if self.isa == Isa::Host { if self.isa == Isa::Host {
let link_fn = |module: &Module| { let link_fn = |module: &Module| {
let working_directory = self.working_directory.path().to_owned(); let working_directory = self.working_directory.path().to_owned();
@ -939,7 +1205,7 @@ impl Nac3 {
.expect("couldn't write module to file"); .expect("couldn't write module to file");
link_with_lld( link_with_lld(
filename.to_string(), filename.to_string(),
working_directory.join("module.o").to_string_lossy().to_string() working_directory.join("module.o").to_string_lossy().to_string(),
)?; )?;
Ok(()) Ok(())
}; };
@ -975,7 +1241,7 @@ impl Nac3 {
py: Python, py: Python,
) -> PyResult<PyObject> { ) -> PyResult<PyObject> {
let target_machine = self.get_llvm_target_machine(); let target_machine = self.get_llvm_target_machine();
if self.isa == Isa::Host { if self.isa == Isa::Host {
let link_fn = |module: &Module| { let link_fn = |module: &Module| {
let working_directory = self.working_directory.path().to_owned(); let working_directory = self.working_directory.path().to_owned();
@ -987,7 +1253,7 @@ impl Nac3 {
let filename = filename_path.to_str().unwrap(); let filename = filename_path.to_str().unwrap();
link_with_lld( link_with_lld(
filename.to_string(), filename.to_string(),
working_directory.join("module.o").to_string_lossy().to_string() working_directory.join("module.o").to_string_lossy().to_string(),
)?; )?;
Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into()) Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,15 @@
use inkwell::{values::BasicValueEnum, AddressSpace, AtomicOrdering}; use itertools::Either;
use nac3core::codegen::CodeGenContext;
use nac3core::{
codegen::CodeGenContext,
inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
},
};
/// Functions for manipulating the timeline. /// Functions for manipulating the timeline.
pub trait TimeFns { pub trait TimeFns {
/// Emits LLVM IR for `now_mu`. /// Emits LLVM IR for `now_mu`.
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>; fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>;
@ -26,32 +32,33 @@ impl TimeFns for NowPinningTimeFns64 {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = let now_hiptr = ctx
ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr"); .builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else { .map(BasicValueEnum::into_pointer_value)
unreachable!() .unwrap();
};
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}; }
.unwrap();
let (BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_lo)) = ( let now_hi = ctx
ctx.builder.build_load(now_hiptr, "now.hi"), .builder
ctx.builder.build_load(now_loptr, "now.lo"), .build_load(now_hiptr, "now.hi")
) else { .map(BasicValueEnum::into_int_value)
unreachable!() .unwrap();
}; let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, ""); let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi = ctx.builder.build_left_shift( let shifted_hi =
zext_hi, ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
i64_type.const_int(32, false), let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
"", ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into).unwrap()
);
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "");
ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").into()
} }
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -59,105 +66,100 @@ impl TimeFns for NowPinningTimeFns64 {
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
let BasicValueEnum::IntValue(time) = t else { let time = t.into_int_value();
unreachable!()
};
let time_hi = ctx.builder.build_int_truncate( let time_hi = ctx
ctx.builder.build_right_shift(time, i64_32, false, "time.hi"), .builder
i32_type, .build_int_truncate(
"", ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
); i32_type,
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo"); "",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now = ctx let now = ctx
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder.build_bitcast( let now_hiptr = ctx
now, .builder
i32_type.ptr_type(AddressSpace::default()), .build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
"now.hi.addr", .map(BasicValueEnum::into_pointer_value)
); .unwrap();
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}; }
.unwrap();
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
ctx.builder ctx.builder
.build_store(now_loptr, time_lo) .build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
} }
fn emit_delay_mu<'ctx>( fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let now = ctx let now = ctx
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = let now_hiptr = ctx
ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr"); .builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else { .map(BasicValueEnum::into_pointer_value)
unreachable!() .unwrap();
};
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}; }
.unwrap();
let ( let now_hi = ctx
BasicValueEnum::IntValue(now_hi), .builder
BasicValueEnum::IntValue(now_lo), .build_load(now_hiptr, "now.hi")
BasicValueEnum::IntValue(dt), .map(BasicValueEnum::into_int_value)
) = ( .unwrap();
ctx.builder.build_load(now_hiptr, "now.hi"), let now_lo = ctx
ctx.builder.build_load(now_loptr, "now.lo"), .builder
dt, .build_load(now_loptr, "now.lo")
) else { .map(BasicValueEnum::into_int_value)
unreachable!() .unwrap();
}; let dt = dt.into_int_value();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, ""); let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi = ctx.builder.build_left_shift( let shifted_hi =
zext_hi, ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
i64_type.const_int(32, false), let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
"", let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now").unwrap();
);
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "");
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now");
let time = ctx.builder.build_int_add(now_val, dt, "time"); let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx.builder.build_int_truncate( let time_hi = ctx
ctx.builder.build_right_shift( .builder
time, .build_int_truncate(
i64_type.const_int(32, false), ctx.builder
false, .build_right_shift(time, i64_type.const_int(32, false), false, "")
"", .unwrap(),
), i32_type,
i32_type, "time.hi",
"time.hi", )
); .unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo"); let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
ctx.builder ctx.builder
.build_store(now_loptr, time_lo) .build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
} }
@ -174,16 +176,16 @@ impl TimeFns for NowPinningTimeFns {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now"); let now_raw = ctx
.builder
let BasicValueEnum::IntValue(now_raw) = now_raw else { .build_load(now.as_pointer_value(), "now")
unreachable!() .map(BasicValueEnum::into_int_value)
}; .unwrap();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo"); let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi"); let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
ctx.builder.build_or(now_lo, now_hi, "now_mu").into() ctx.builder.build_or(now_lo, now_hi, "now_mu").map(Into::into).unwrap()
} }
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -191,48 +193,44 @@ impl TimeFns for NowPinningTimeFns {
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
let BasicValueEnum::IntValue(time) = t else { let time = t.into_int_value();
unreachable!()
};
let time_hi = ctx.builder.build_int_truncate( let time_hi = ctx
ctx.builder.build_right_shift(time, i64_32, false, ""), .builder
i32_type, .build_int_truncate(
"time.hi", ctx.builder.build_right_shift(time, i64_32, false, "").unwrap(),
); i32_type,
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc"); "time.hi",
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc").unwrap();
let now = ctx let now = ctx
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder.build_bitcast( let now_hiptr = ctx
now, .builder
i32_type.ptr_type(AddressSpace::default()), .build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
"now.hi.addr", .map(BasicValueEnum::into_pointer_value)
); .unwrap();
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}; }
.unwrap();
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
ctx.builder ctx.builder
.build_store(now_loptr, time_lo) .build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
} }
fn emit_delay_mu<'ctx>( fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
@ -240,41 +238,45 @@ impl TimeFns for NowPinningTimeFns {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder.build_load(now.as_pointer_value(), ""); let now_raw = ctx
.builder
.build_load(now.as_pointer_value(), "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = (now_raw, dt) else { let dt = dt.into_int_value();
unreachable!()
};
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo"); let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi"); let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val"); let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time"); let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx.builder.build_int_truncate( let time_hi = ctx
ctx.builder.build_right_shift(time, i64_32, false, "time.hi"), .builder
i32_type, .build_int_truncate(
"now_trunc", ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
); i32_type,
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo"); "now_trunc",
let now_hiptr = ctx.builder.build_bitcast( )
now, .unwrap();
i32_type.ptr_type(AddressSpace::default()), let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
"now.hi.addr", let now_hiptr = ctx
); .builder
.build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else { .map(BasicValueEnum::into_pointer_value)
unreachable!() .unwrap();
};
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}; }
.unwrap();
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
ctx.builder ctx.builder
.build_store(now_loptr, time_lo) .build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
} }
@ -289,7 +291,11 @@ impl TimeFns for ExternTimeFns {
let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| { let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| {
ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None) ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None)
}); });
ctx.builder.build_call(now_mu, &[], "now_mu").try_as_basic_value().left().unwrap() ctx.builder
.build_call(now_mu, &[], "now_mu")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap()
} }
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -300,14 +306,10 @@ impl TimeFns for ExternTimeFns {
None, None,
) )
}); });
ctx.builder.build_call(at_mu, &[t.into()], "at_mu"); ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap();
} }
fn emit_delay_mu<'ctx>( fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| { let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| {
ctx.module.add_function( ctx.module.add_function(
"delay_mu", "delay_mu",
@ -315,7 +317,7 @@ impl TimeFns for ExternTimeFns {
None, None,
) )
}); });
ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu"); ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu").unwrap();
} }
} }

View File

@ -10,7 +10,6 @@ constant-optimization = ["fold"]
fold = [] fold = []
[dependencies] [dependencies]
lazy_static = "1.4"
parking_lot = "0.12" parking_lot = "0.12"
string-interner = "0.14" string-interner = "0.17"
fxhash = "0.2" fxhash = "0.2"

File diff suppressed because it is too large Load Diff

View File

@ -28,12 +28,12 @@ impl From<bool> for Constant {
} }
impl From<i32> for Constant { impl From<i32> for Constant {
fn from(i: i32) -> Constant { fn from(i: i32) -> Constant {
Self::Int(i as i128) Self::Int(i128::from(i))
} }
} }
impl From<i64> for Constant { impl From<i64> for Constant {
fn from(i: i64) -> Constant { fn from(i: i64) -> Constant {
Self::Int(i as i128) Self::Int(i128::from(i))
} }
} }
@ -50,6 +50,7 @@ pub enum ConversionFlag {
} }
impl ConversionFlag { impl ConversionFlag {
#[must_use]
pub fn try_from_byte(b: u8) -> Option<Self> { pub fn try_from_byte(b: u8) -> Option<Self> {
match b { match b {
b's' => Some(Self::Str), b's' => Some(Self::Str),
@ -69,6 +70,7 @@ pub struct ConstantOptimizer {
#[cfg(feature = "constant-optimization")] #[cfg(feature = "constant-optimization")]
impl ConstantOptimizer { impl ConstantOptimizer {
#[inline] #[inline]
#[must_use]
pub fn new() -> Self { pub fn new() -> Self {
Self { _priv: () } Self { _priv: () }
} }
@ -85,33 +87,22 @@ impl<U> crate::fold::Fold<U> for ConstantOptimizer {
fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> { fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> {
match node.node { match node.node {
crate::ExprKind::Tuple { elts, ctx } => { crate::ExprKind::Tuple { elts, ctx } => {
let elts = elts let elts =
.into_iter() elts.into_iter().map(|x| self.fold_expr(x)).collect::<Result<Vec<_>, _>>()?;
.map(|x| self.fold_expr(x)) let expr =
.collect::<Result<Vec<_>, _>>()?; if elts.iter().all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) {
let expr = if elts let tuple = elts
.iter() .into_iter()
.all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) .map(|e| match e.node {
{ crate::ExprKind::Constant { value, .. } => value,
let tuple = elts _ => unreachable!(),
.into_iter() })
.map(|e| match e.node { .collect();
crate::ExprKind::Constant { value, .. } => value, crate::ExprKind::Constant { value: Constant::Tuple(tuple), kind: None }
_ => unreachable!(), } else {
}) crate::ExprKind::Tuple { elts, ctx }
.collect(); };
crate::ExprKind::Constant { Ok(crate::Expr { node: expr, custom: node.custom, location: node.location })
value: Constant::Tuple(tuple),
kind: None,
}
} else {
crate::ExprKind::Tuple { elts, ctx }
};
Ok(crate::Expr {
node: expr,
custom: node.custom,
location: node.location,
})
} }
_ => crate::fold::fold_expr(self, node), _ => crate::fold::fold_expr(self, node),
} }
@ -127,7 +118,7 @@ mod tests {
use crate::fold::Fold; use crate::fold::Fold;
use crate::*; use crate::*;
let location = Location::new(0, 0, Default::default()); let location = Location::new(0, 0, FileName::default());
let custom = (); let custom = ();
let ast = Located { let ast = Located {
location, location,
@ -138,18 +129,12 @@ mod tests {
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 1.into(), kind: None },
value: 1.into(),
kind: None,
},
}, },
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 2.into(), kind: None },
value: 2.into(),
kind: None,
},
}, },
Located { Located {
location, location,
@ -160,26 +145,17 @@ mod tests {
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 3.into(), kind: None },
value: 3.into(),
kind: None,
},
}, },
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 4.into(), kind: None },
value: 4.into(),
kind: None,
},
}, },
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 5.into(), kind: None },
value: 5.into(),
kind: None,
},
}, },
], ],
}, },
@ -187,9 +163,7 @@ mod tests {
], ],
}, },
}; };
let new_ast = ConstantOptimizer::new() let new_ast = ConstantOptimizer::new().fold_expr(ast).unwrap_or_else(|e| match e {});
.fold_expr(ast)
.unwrap_or_else(|e| match e {});
assert_eq!( assert_eq!(
new_ast, new_ast,
Located { Located {
@ -199,11 +173,7 @@ mod tests {
value: Constant::Tuple(vec![ value: Constant::Tuple(vec![
1.into(), 1.into(),
2.into(), 2.into(),
Constant::Tuple(vec![ Constant::Tuple(vec![3.into(), 4.into(), 5.into(),])
3.into(),
4.into(),
5.into(),
])
]), ]),
kind: None kind: None
}, },

View File

@ -64,11 +64,4 @@ macro_rules! simple_fold {
}; };
} }
simple_fold!( simple_fold!(usize, String, bool, StrRef, constant::Constant, constant::ConversionFlag);
usize,
String,
bool,
StrRef,
constant::Constant,
constant::ConversionFlag
);

View File

@ -2,6 +2,7 @@ use crate::{Constant, ExprKind};
impl<U> ExprKind<U> { impl<U> ExprKind<U> {
/// Returns a short name for the node suitable for use in error messages. /// Returns a short name for the node suitable for use in error messages.
#[must_use]
pub fn name(&self) -> &'static str { pub fn name(&self) -> &'static str {
match self { match self {
ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => { ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => {
@ -34,10 +35,7 @@ impl<U> ExprKind<U> {
ExprKind::Starred { .. } => "starred", ExprKind::Starred { .. } => "starred",
ExprKind::Slice { .. } => "slice", ExprKind::Slice { .. } => "slice",
ExprKind::JoinedStr { values } => { ExprKind::JoinedStr { values } => {
if values if values.iter().any(|e| matches!(e.node, ExprKind::JoinedStr { .. })) {
.iter()
.any(|e| matches!(e.node, ExprKind::JoinedStr { .. }))
{
"f-string expression" "f-string expression"
} else { } else {
"literal" "literal"

View File

@ -1,5 +1,12 @@
#[macro_use] #![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
extern crate lazy_static; #![warn(clippy::pedantic)]
#![allow(
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::too_many_lines,
clippy::wildcard_imports
)]
mod ast_gen; mod ast_gen;
mod constant; mod constant;
@ -9,6 +16,6 @@ mod impls;
mod location; mod location;
pub use ast_gen::*; pub use ast_gen::*;
pub use location::{Location, FileName}; pub use location::{FileName, Location};
pub type Suite<U = ()> = Vec<Stmt<U>>; pub type Suite<U = ()> = Vec<Stmt<U>>;

View File

@ -1,6 +1,6 @@
//! Datatypes to support source location information. //! Datatypes to support source location information.
use std::cmp::Ordering;
use crate::ast_gen::StrRef; use crate::ast_gen::StrRef;
use std::cmp::Ordering;
use std::fmt; use std::fmt;
#[derive(Clone, Copy, Debug, Eq, PartialEq)] #[derive(Clone, Copy, Debug, Eq, PartialEq)]
@ -22,7 +22,7 @@ impl From<String> for FileName {
pub struct Location { pub struct Location {
pub row: usize, pub row: usize,
pub column: usize, pub column: usize,
pub file: FileName pub file: FileName,
} }
impl fmt::Display for Location { impl fmt::Display for Location {
@ -35,12 +35,12 @@ impl Ord for Location {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string()); let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string());
if file_cmp != Ordering::Equal { if file_cmp != Ordering::Equal {
return file_cmp return file_cmp;
} }
let row_cmp = self.row.cmp(&other.row); let row_cmp = self.row.cmp(&other.row);
if row_cmp != Ordering::Equal { if row_cmp != Ordering::Equal {
return row_cmp return row_cmp;
} }
self.column.cmp(&other.column) self.column.cmp(&other.column)
@ -76,23 +76,22 @@ impl Location {
) )
} }
} }
Visualize { Visualize { loc: *self, line, desc }
loc: *self,
line,
desc,
}
} }
} }
impl Location { impl Location {
#[must_use]
pub fn new(row: usize, column: usize, file: FileName) -> Self { pub fn new(row: usize, column: usize, file: FileName) -> Self {
Location { row, column, file } Location { row, column, file }
} }
#[must_use]
pub fn row(&self) -> usize { pub fn row(&self) -> usize {
self.row self.row
} }
#[must_use]
pub fn column(&self) -> usize { pub fn column(&self) -> usize {
self.column self.column
} }

View File

@ -4,17 +4,26 @@ version = "0.1.0"
authors = ["M-Labs"] authors = ["M-Labs"]
edition = "2021" edition = "2021"
[features]
default = ["derive"]
derive = ["dep:nac3core_derive"]
no-escape-analysis = []
[dependencies] [dependencies]
itertools = "0.12" itertools = "0.13"
crossbeam = "0.8" crossbeam = "0.8"
indexmap = "2.6"
parking_lot = "0.12" parking_lot = "0.12"
rayon = "1.5" rayon = "1.10"
nac3core_derive = { path = "nac3core_derive", optional = true }
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
strum = "0.26"
strum_macros = "0.26"
[dependencies.inkwell] [dependencies.inkwell]
version = "0.2" version = "0.5"
default-features = false default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"] features = ["llvm14-0-prefer-dynamic", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[dev-dependencies] [dev-dependencies]
test-case = "1.2.0" test-case = "1.2.0"

View File

@ -1,4 +1,3 @@
use regex::Regex;
use std::{ use std::{
env, env,
fs::File, fs::File,
@ -7,35 +6,58 @@ use std::{
process::{Command, Stdio}, process::{Command, Stdio},
}; };
use regex::Regex;
fn main() { fn main() {
const FILE: &str = "src/codegen/irrt/irrt.c"; let out_dir = env::var("OUT_DIR").unwrap();
let out_dir = Path::new(&out_dir);
let irrt_dir = Path::new("irrt");
let irrt_cpp_path = irrt_dir.join("irrt.cpp");
/* /*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode. * HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get. * Compiling for WASM32 and filtering the output with regex is the closest we can get.
*/ */
const FLAG: &[&str] = &[ let mut flags: Vec<&str> = vec![
"--target=wasm32", "--target=wasm32",
FILE, "-x",
"-O3", "c++",
"-std=c++20",
"-fno-discard-value-names",
"-fno-exceptions",
"-fno-rtti",
"-emit-llvm", "-emit-llvm",
"-S", "-S",
"-Wall", "-Wall",
"-Wextra", "-Wextra",
"-o", "-o",
"-", "-",
"-I",
irrt_dir.to_str().unwrap(),
irrt_cpp_path.to_str().unwrap(),
]; ];
println!("cargo:rerun-if-changed={FILE}"); match env::var("PROFILE").as_deref() {
let out_dir = env::var("OUT_DIR").unwrap(); Ok("debug") => {
let out_path = Path::new(&out_dir); flags.push("-O0");
flags.push("-DIRRT_DEBUG_ASSERT");
}
Ok("release") => {
flags.push("-O3");
}
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
}
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
// Compile IRRT and capture the LLVM IR output
let output = Command::new("clang-irrt") let output = Command::new("clang-irrt")
.args(FLAG) .args(flags)
.output() .output()
.map(|o| { .inspect(|o| {
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap()); assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
o
}) })
.unwrap(); .unwrap();
@ -43,7 +65,17 @@ fn main() {
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n"); let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
let mut filtered_output = String::with_capacity(output.len()); let mut filtered_output = String::with_capacity(output.len());
let regex_filter = Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap(); // Filter out irrelevant IR
//
// Regex:
// - `(?ms:^define.*?\}$)` captures LLVM `define` blocks
// - `(?m:^declare.*?$)` captures LLVM `declare` lines
// - `(?m:^%.+?=\s*type\s*\{.+?\}$)` captures LLVM `type` declarations
// - `(?m:^@.+?=.+$)` captures global constants
let regex_filter = Regex::new(
r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)|(?m:^@.+?=.+$)",
)
.unwrap();
for f in regex_filter.captures_iter(&output) { for f in regex_filter.captures_iter(&output) {
assert_eq!(f.len(), 1); assert_eq!(f.len(), 1);
filtered_output.push_str(&f[0]); filtered_output.push_str(&f[0]);
@ -54,18 +86,22 @@ fn main() {
.unwrap() .unwrap()
.replace_all(&filtered_output, ""); .replace_all(&filtered_output, "");
println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT"); // For debugging
if env::var("DEBUG_DUMP_IRRT").is_ok() { // Doing `DEBUG_DUMP_IRRT=1 cargo build -p nac3core` dumps the LLVM IR generated
let mut file = File::create(out_path.join("irrt.ll")).unwrap(); const DEBUG_DUMP_IRRT: &str = "DEBUG_DUMP_IRRT";
println!("cargo:rerun-if-env-changed={DEBUG_DUMP_IRRT}");
if env::var(DEBUG_DUMP_IRRT).is_ok() {
let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
file.write_all(output.as_bytes()).unwrap(); file.write_all(output.as_bytes()).unwrap();
let mut file = File::create(out_path.join("irrt-filtered.ll")).unwrap();
let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
file.write_all(filtered_output.as_bytes()).unwrap(); file.write_all(filtered_output.as_bytes()).unwrap();
} }
let mut llvm_as = Command::new("llvm-as-irrt") let mut llvm_as = Command::new("llvm-as-irrt")
.stdin(Stdio::piped()) .stdin(Stdio::piped())
.arg("-o") .arg("-o")
.arg(out_path.join("irrt.bc")) .arg(out_dir.join("irrt.bc"))
.spawn() .spawn()
.unwrap(); .unwrap();
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap(); llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();

10
nac3core/irrt/irrt.cpp Normal file
View File

@ -0,0 +1,10 @@
#include "irrt/exception.hpp"
#include "irrt/list.hpp"
#include "irrt/math.hpp"
#include "irrt/ndarray.hpp"
#include "irrt/range.hpp"
#include "irrt/slice.hpp"
#include "irrt/ndarray/basic.hpp"
#include "irrt/ndarray/def.hpp"
#include "irrt/ndarray/iter.hpp"
#include "irrt/ndarray/indexing.hpp"

View File

@ -0,0 +1,9 @@
#pragma once
#include "irrt/int_types.hpp"
template<typename SizeT>
struct CSlice {
void* base;
SizeT len;
};

View File

@ -0,0 +1,25 @@
#pragma once
// Set in nac3core/build.rs
#ifdef IRRT_DEBUG_ASSERT
#define IRRT_DEBUG_ASSERT_BOOL true
#else
#define IRRT_DEBUG_ASSERT_BOOL false
#endif
#define raise_debug_assert(SizeT, msg, param1, param2, param3) \
raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3)
#define debug_assert_eq(SizeT, lhs, rhs) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if ((lhs) != (rhs)) { \
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
} \
}
#define debug_assert(SizeT, expr) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if (!(expr)) { \
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \
} \
}

View File

@ -0,0 +1,85 @@
#pragma once
#include "irrt/cslice.hpp"
#include "irrt/int_types.hpp"
/**
* @brief The int type of ARTIQ exception IDs.
*/
using ExceptionId = int32_t;
/*
* Set of exceptions C++ IRRT can use.
* Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`.
*/
extern "C" {
ExceptionId EXN_INDEX_ERROR;
ExceptionId EXN_VALUE_ERROR;
ExceptionId EXN_ASSERTION_ERROR;
ExceptionId EXN_TYPE_ERROR;
}
/**
* @brief Extern function to `__nac3_raise`
*
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
*/
extern "C" void __nac3_raise(void* err);
namespace {
/**
* @brief NAC3's Exception struct
*/
template<typename SizeT>
struct Exception {
ExceptionId id;
CSlice<SizeT> filename;
int32_t line;
int32_t column;
CSlice<SizeT> function;
CSlice<SizeT> msg;
int64_t params[3];
};
constexpr int64_t NO_PARAM = 0;
template<typename SizeT>
void _raise_exception_helper(ExceptionId id,
const char* filename,
int32_t line,
const char* function,
const char* msg,
int64_t param0,
int64_t param1,
int64_t param2) {
Exception<SizeT> e = {
.id = id,
.filename = {.base = reinterpret_cast<void*>(const_cast<char*>(filename)),
.len = static_cast<SizeT>(__builtin_strlen(filename))},
.line = line,
.column = 0,
.function = {.base = reinterpret_cast<void*>(const_cast<char*>(function)),
.len = static_cast<SizeT>(__builtin_strlen(function))},
.msg = {.base = reinterpret_cast<void*>(const_cast<char*>(msg)),
.len = static_cast<SizeT>(__builtin_strlen(msg))},
};
e.params[0] = param0;
e.params[1] = param1;
e.params[2] = param2;
__nac3_raise(reinterpret_cast<void*>(&e));
__builtin_unreachable();
}
} // namespace
/**
* @brief Raise an exception with location details (location in the IRRT source files).
* @param SizeT The runtime `size_t` type.
* @param id The ID of the exception to raise.
* @param msg A global constant C-string of the error message.
*
* `param0` to `param2` are optional format arguments of `msg`. They should be set to
* `NO_PARAM` to indicate they are unused.
*/
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)

View File

@ -0,0 +1,27 @@
#pragma once
#if __STDC_VERSION__ >= 202000
using int8_t = _BitInt(8);
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32);
using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
#else
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wdeprecated-type"
using int8_t = _ExtInt(8);
using uint8_t = unsigned _ExtInt(8);
using int32_t = _ExtInt(32);
using uint32_t = unsigned _ExtInt(32);
using int64_t = _ExtInt(64);
using uint64_t = unsigned _ExtInt(64);
#pragma clang diagnostic pop
#endif
// NDArray indices are always `uint32_t`.
using NDIndexInt = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t;

View File

@ -0,0 +1,81 @@
#pragma once
#include "irrt/int_types.hpp"
#include "irrt/math_util.hpp"
extern "C" {
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
void* dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
void* src_arr,
SliceIndex src_arr_len,
const SliceIndex size) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0)
return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_start * size,
static_cast<uint8_t*>(src_arr) + src_start * size, src_len * size);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + (dest_start + src_len) * size,
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca = (dest_arr == src_arr)
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end));
if (need_alloca) {
void* tmp = __builtin_alloca(src_arr_len * size);
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind, static_cast<uint8_t*>(src_arr) + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 4,
static_cast<uint8_t*>(src_arr) + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * 8,
static_cast<uint8_t*>(src_arr) + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
static_cast<uint8_t*>(src_arr) + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(static_cast<uint8_t*>(dest_arr) + dest_ind * size,
static_cast<uint8_t*>(dest_arr) + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
} // extern "C"

View File

@ -0,0 +1,93 @@
#pragma once
namespace {
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template<typename T>
T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
} // namespace
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) { \
return __nac3_int_exp_impl(base, exp); \
}
extern "C" {
// Putting semicolons here to make clang-format not reformat this into
// a stair shape.
DEF_nac3_int_exp_(int32_t);
DEF_nac3_int_exp_(int64_t);
DEF_nac3_int_exp_(uint32_t);
DEF_nac3_int_exp_(uint64_t);
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
} // namespace

View File

@ -0,0 +1,13 @@
#pragma once
namespace {
template<typename T>
const T& max(const T& a, const T& b) {
return a > b ? a : b;
}
template<typename T>
const T& min(const T& a, const T& b) {
return a > b ? b : a;
}
} // namespace

View File

@ -0,0 +1,151 @@
#pragma once
#include "irrt/int_types.hpp"
// TODO: To be deleted since NDArray with strides is done.
namespace {
template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template<typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template<typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims,
SizeT num_dims,
const NDIndexInt* indices,
SizeT num_indices) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
SizeT lhs_ndims,
const SizeT* rhs_dims,
SizeT rhs_ndims,
SizeT* out_dims) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
} // namespace
extern "C" {
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims,
uint64_t num_dims,
const NDIndexInt* indices,
uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
uint32_t lhs_ndims,
const uint32_t* rhs_dims,
uint32_t rhs_ndims,
uint32_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
uint64_t lhs_ndims,
const uint64_t* rhs_dims,
uint64_t rhs_ndims,
uint64_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims,
const NDIndexInt* in_idx,
NDIndexInt* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
} // namespace

View File

@ -0,0 +1,342 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
namespace ndarray {
namespace basic {
/**
* @brief Assert that `shape` does not contain negative dimensions.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape to check on
*/
template<typename SizeT>
void assert_shape_no_negative(SizeT ndims, const SizeT* shape) {
for (SizeT axis = 0; axis < ndims; axis++) {
if (shape[axis] < 0) {
raise_exception(SizeT, EXN_VALUE_ERROR,
"negative dimensions are not allowed; axis {0} "
"has dimension {1}",
axis, shape[axis], NO_PARAM);
}
}
}
/**
* @brief Assert that two shapes are the same in the context of writing output to an ndarray.
*/
template<typename SizeT>
void assert_output_shape_same(SizeT ndarray_ndims,
const SizeT* ndarray_shape,
SizeT output_ndims,
const SizeT* output_shape) {
if (ndarray_ndims != output_ndims) {
// There is no corresponding NumPy error message like this.
raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot write output of ndims {0} to an ndarray with ndims {1}",
output_ndims, ndarray_ndims, NO_PARAM);
}
for (SizeT axis = 0; axis < ndarray_ndims; axis++) {
if (ndarray_shape[axis] != output_shape[axis]) {
// There is no corresponding NumPy error message like this.
raise_exception(SizeT, EXN_VALUE_ERROR,
"Mismatched dimensions on axis {0}, output has "
"dimension {1}, but destination ndarray has dimension {2}.",
axis, output_shape[axis], ndarray_shape[axis]);
}
}
}
/**
* @brief Return the number of elements of an ndarray given its shape.
*
* @param ndims Number of dimensions in `shape`
* @param shape The shape of the ndarray
*/
template<typename SizeT>
SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
SizeT size = 1;
for (SizeT axis = 0; axis < ndims; axis++)
size *= shape[axis];
return size;
}
/**
* @brief Compute the array indices of the `nth` (0-based) element of an ndarray given only its shape.
*
* @param ndims Number of elements in `shape` and `indices`
* @param shape The shape of the ndarray
* @param indices The returned indices indexing the ndarray with shape `shape`.
* @param nth The index of the element of interest.
*/
template<typename SizeT>
void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = ndims - i - 1;
SizeT dim = shape[axis];
indices[axis] = nth % dim;
nth /= dim;
}
}
/**
* @brief Return the number of elements of an `ndarray`
*
* This function corresponds to `<an_ndarray>.size`
*/
template<typename SizeT>
SizeT size(const NDArray<SizeT>* ndarray) {
return calc_size_from_shape(ndarray->ndims, ndarray->shape);
}
/**
* @brief Return of the number of its content of an `ndarray`.
*
* This function corresponds to `<an_ndarray>.nbytes`.
*/
template<typename SizeT>
SizeT nbytes(const NDArray<SizeT>* ndarray) {
return size(ndarray) * ndarray->itemsize;
}
/**
* @brief Get the `len()` of an ndarray, and asserts that `ndarray` is a sized object.
*
* This function corresponds to `<an_ndarray>.__len__`.
*
* @param dst_length The length.
*/
template<typename SizeT>
SizeT len(const NDArray<SizeT>* ndarray) {
if (ndarray->ndims != 0) {
return ndarray->shape[0];
}
// numpy prohibits `__len__` on unsized objects
raise_exception(SizeT, EXN_TYPE_ERROR, "len() of unsized object", NO_PARAM, NO_PARAM, NO_PARAM);
__builtin_unreachable();
}
/**
* @brief Return a boolean indicating if `ndarray` is (C-)contiguous.
*
* You may want to see ndarray's rules for C-contiguity:
* https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
*/
template<typename SizeT>
bool is_c_contiguous(const NDArray<SizeT>* ndarray) {
// References:
// - tinynumpy's implementation:
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L102
// - ndarray's flags["C_CONTIGUOUS"]:
// https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags
// - ndarray's rules for C-contiguity:
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45
// From
// https://github.com/numpy/numpy/blob/df256d0d2f3bc6833699529824781c58f9c6e697/numpy/core/src/multiarray/flagsobject.c#L95C1-L99C45:
//
// The traditional rule is that for an array to be flagged as C contiguous,
// the following must hold:
//
// strides[-1] == itemsize
// strides[i] == shape[i+1] * strides[i + 1]
// [...]
// According to these rules, a 0- or 1-dimensional array is either both
// C- and F-contiguous, or neither; and an array with 2+ dimensions
// can be C- or F- contiguous, or neither, but not both. Though there
// there are exceptions for arrays with zero or one item, in the first
// case the check is relaxed up to and including the first dimension
// with shape[i] == 0. In the second case `strides == itemsize` will
// can be true for all dimensions and both flags are set.
if (ndarray->ndims == 0) {
return true;
}
if (ndarray->strides[ndarray->ndims - 1] != ndarray->itemsize) {
return false;
}
for (SizeT i = 1; i < ndarray->ndims; i++) {
SizeT axis_i = ndarray->ndims - i - 1;
if (ndarray->strides[axis_i] != ndarray->shape[axis_i + 1] * ndarray->strides[axis_i + 1]) {
return false;
}
}
return true;
}
/**
* @brief Return the pointer to the element indexed by `indices` along the ndarray's axes.
*
* This function does no bound check.
*/
template<typename SizeT>
void* get_pelement_by_indices(const NDArray<SizeT>* ndarray, const SizeT* indices) {
void* element = ndarray->data;
for (SizeT dim_i = 0; dim_i < ndarray->ndims; dim_i++)
element = static_cast<uint8_t*>(element) + indices[dim_i] * ndarray->strides[dim_i];
return element;
}
/**
* @brief Return the pointer to the nth (0-based) element of `ndarray` in flattened view.
*
* This function does no bound check.
*/
template<typename SizeT>
void* get_nth_pelement(const NDArray<SizeT>* ndarray, SizeT nth) {
void* element = ndarray->data;
for (SizeT i = 0; i < ndarray->ndims; i++) {
SizeT axis = ndarray->ndims - i - 1;
SizeT dim = ndarray->shape[axis];
element = static_cast<uint8_t*>(element) + ndarray->strides[axis] * (nth % dim);
nth /= dim;
}
return element;
}
/**
* @brief Update the strides of an ndarray given an ndarray `shape` to be contiguous.
*
* You might want to read https://ajcr.net/stride-guide-part-1/.
*/
template<typename SizeT>
void set_strides_by_shape(NDArray<SizeT>* ndarray) {
SizeT stride_product = 1;
for (SizeT i = 0; i < ndarray->ndims; i++) {
SizeT axis = ndarray->ndims - i - 1;
ndarray->strides[axis] = stride_product * ndarray->itemsize;
stride_product *= ndarray->shape[axis];
}
}
/**
* @brief Set an element in `ndarray`.
*
* @param pelement Pointer to the element in `ndarray` to be set.
* @param pvalue Pointer to the value `pelement` will be set to.
*/
template<typename SizeT>
void set_pelement_value(NDArray<SizeT>* ndarray, void* pelement, const void* pvalue) {
__builtin_memcpy(pelement, pvalue, ndarray->itemsize);
}
/**
* @brief Copy data from one ndarray to another of the exact same size and itemsize.
*
* Both ndarrays will be viewed in their flatten views when copying the elements.
*/
template<typename SizeT>
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// TODO: Make this faster with memcpy when we see a contiguous segment.
// TODO: Handle overlapping.
debug_assert_eq(SizeT, src_ndarray->itemsize, dst_ndarray->itemsize);
for (SizeT i = 0; i < size(src_ndarray); i++) {
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
}
}
} // namespace basic
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::basic;
void __nac3_ndarray_util_assert_shape_no_negative(int32_t ndims, int32_t* shape) {
assert_shape_no_negative(ndims, shape);
}
void __nac3_ndarray_util_assert_shape_no_negative64(int64_t ndims, int64_t* shape) {
assert_shape_no_negative(ndims, shape);
}
void __nac3_ndarray_util_assert_output_shape_same(int32_t ndarray_ndims,
const int32_t* ndarray_shape,
int32_t output_ndims,
const int32_t* output_shape) {
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
}
void __nac3_ndarray_util_assert_output_shape_same64(int64_t ndarray_ndims,
const int64_t* ndarray_shape,
int64_t output_ndims,
const int64_t* output_shape) {
assert_output_shape_same(ndarray_ndims, ndarray_shape, output_ndims, output_shape);
}
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
return size(ndarray);
}
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
return size(ndarray);
}
uint32_t __nac3_ndarray_nbytes(NDArray<int32_t>* ndarray) {
return nbytes(ndarray);
}
uint64_t __nac3_ndarray_nbytes64(NDArray<int64_t>* ndarray) {
return nbytes(ndarray);
}
int32_t __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
return len(ndarray);
}
int64_t __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
return len(ndarray);
}
bool __nac3_ndarray_is_c_contiguous(NDArray<int32_t>* ndarray) {
return is_c_contiguous(ndarray);
}
bool __nac3_ndarray_is_c_contiguous64(NDArray<int64_t>* ndarray) {
return is_c_contiguous(ndarray);
}
void* __nac3_ndarray_get_nth_pelement(const NDArray<int32_t>* ndarray, int32_t nth) {
return get_nth_pelement(ndarray, nth);
}
void* __nac3_ndarray_get_nth_pelement64(const NDArray<int64_t>* ndarray, int64_t nth) {
return get_nth_pelement(ndarray, nth);
}
void* __nac3_ndarray_get_pelement_by_indices(const NDArray<int32_t>* ndarray, int32_t* indices) {
return get_pelement_by_indices(ndarray, indices);
}
void* __nac3_ndarray_get_pelement_by_indices64(const NDArray<int64_t>* ndarray, int64_t* indices) {
return get_pelement_by_indices(ndarray, indices);
}
void __nac3_ndarray_set_strides_by_shape(NDArray<int32_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_set_strides_by_shape64(NDArray<int64_t>* ndarray) {
set_strides_by_shape(ndarray);
}
void __nac3_ndarray_copy_data(NDArray<int32_t>* src_ndarray, NDArray<int32_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
void __nac3_ndarray_copy_data64(NDArray<int64_t>* src_ndarray, NDArray<int64_t>* dst_ndarray) {
copy_data(src_ndarray, dst_ndarray);
}
}

View File

@ -0,0 +1,51 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
/**
* @brief The NDArray object
*
* Official numpy implementation:
* https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst#pyarrayinterface
*
* Note that this implementation is based on `PyArrayInterface` rather of `PyArrayObject`. The
* difference between `PyArrayInterface` and `PyArrayObject` (relevant to our implementation) is
* that `PyArrayInterface` *has* `itemsize` and uses `void*` for its `data`, whereas `PyArrayObject`
* does not require `itemsize` (probably using `strides[-1]` instead) and uses `char*` for its
* `data`. There are also minor differences in the struct layout.
*/
template<typename SizeT>
struct NDArray {
/**
* @brief The number of bytes of a single element in `data`.
*/
SizeT itemsize;
/**
* @brief The number of dimensions of this shape.
*/
SizeT ndims;
/**
* @brief The NDArray shape, with length equal to `ndims`.
*
* Note that it may contain 0.
*/
SizeT* shape;
/**
* @brief Array strides, with length equal to `ndims`
*
* The stride values are in units of bytes, not number of elements.
*
* Note that `strides` can have negative values or contain 0.
*/
SizeT* strides;
/**
* @brief The underlying data this `ndarray` is pointing to.
*/
void* data;
};
} // namespace

View File

@ -0,0 +1,220 @@
#pragma once
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/ndarray/basic.hpp"
#include "irrt/ndarray/def.hpp"
#include "irrt/range.hpp"
#include "irrt/slice.hpp"
namespace {
typedef uint8_t NDIndexType;
/**
* @brief A single element index
*
* `data` points to a `int32_t`.
*/
const NDIndexType ND_INDEX_TYPE_SINGLE_ELEMENT = 0;
/**
* @brief A slice index
*
* `data` points to a `Slice<int32_t>`.
*/
const NDIndexType ND_INDEX_TYPE_SLICE = 1;
/**
* @brief `np.newaxis` / `None`
*
* `data` is unused.
*/
const NDIndexType ND_INDEX_TYPE_NEWAXIS = 2;
/**
* @brief `Ellipsis` / `...`
*
* `data` is unused.
*/
const NDIndexType ND_INDEX_TYPE_ELLIPSIS = 3;
/**
* @brief An index used in ndarray indexing
*
* That is:
* ```
* my_ndarray[::-1, 3, ..., np.newaxis]
* ^^^^ ^ ^^^ ^^^^^^^^^^ each of these is represented by an NDIndex.
* ```
*/
struct NDIndex {
/**
* @brief Enum tag to specify the type of index.
*
* Please see the comment of each enum constant.
*/
NDIndexType type;
/**
* @brief The accompanying data associated with `type`.
*
* Please see the comment of each enum constant.
*/
uint8_t* data;
};
} // namespace
namespace {
namespace ndarray {
namespace indexing {
/**
* @brief Perform ndarray "basic indexing" (https://numpy.org/doc/stable/user/basics.indexing.html#basic-indexing)
*
* This function is very similar to performing `dst_ndarray = src_ndarray[indices]` in Python.
*
* This function also does proper assertions on `indices` to check for out of bounds access and more.
*
* # Notes on `dst_ndarray`
* The caller is responsible for allocating space for the resulting ndarray.
* Here is what this function expects from `dst_ndarray` when called:
* - `dst_ndarray->data` does not have to be initialized.
* - `dst_ndarray->itemsize` does not have to be initialized.
* - `dst_ndarray->ndims` must be initialized, and it must be equal to the expected `ndims` of the `dst_ndarray` after
* indexing `src_ndarray` with `indices`.
* - `dst_ndarray->shape` must be allocated, through it can contain uninitialized values.
* - `dst_ndarray->strides` must be allocated, through it can contain uninitialized values.
* When this function call ends:
* - `dst_ndarray->data` is set to `src_ndarray->data`.
* - `dst_ndarray->itemsize` is set to `src_ndarray->itemsize`.
* - `dst_ndarray->ndims` is unchanged.
* - `dst_ndarray->shape` is updated according to how `src_ndarray` is indexed.
* - `dst_ndarray->strides` is updated accordingly by how ndarray indexing works.
*
* @param indices indices to index `src_ndarray`, ordered in the same way you would write them in Python.
* @param src_ndarray The NDArray to be indexed.
* @param dst_ndarray The resulting NDArray after indexing. Further details in the comments above,
*/
template<typename SizeT>
void index(SizeT num_indices, const NDIndex* indices, const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// Validate `indices`.
// Expected value of `dst_ndarray->ndims`.
SizeT expected_dst_ndims = src_ndarray->ndims;
// To check for "too many indices for array: array is ?-dimensional, but ? were indexed"
SizeT num_indexed = 0;
// There may be ellipsis `...` in `indices`. There can only be 0 or 1 ellipsis.
SizeT num_ellipsis = 0;
for (SizeT i = 0; i < num_indices; i++) {
if (indices[i].type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
expected_dst_ndims--;
num_indexed++;
} else if (indices[i].type == ND_INDEX_TYPE_SLICE) {
num_indexed++;
} else if (indices[i].type == ND_INDEX_TYPE_NEWAXIS) {
expected_dst_ndims++;
} else if (indices[i].type == ND_INDEX_TYPE_ELLIPSIS) {
num_ellipsis++;
if (num_ellipsis > 1) {
raise_exception(SizeT, EXN_INDEX_ERROR, "an index can only have a single ellipsis ('...')", NO_PARAM,
NO_PARAM, NO_PARAM);
}
} else {
__builtin_unreachable();
}
}
debug_assert_eq(SizeT, expected_dst_ndims, dst_ndarray->ndims);
if (src_ndarray->ndims - num_indexed < 0) {
raise_exception(SizeT, EXN_INDEX_ERROR,
"too many indices for array: array is {0}-dimensional, "
"but {1} were indexed",
src_ndarray->ndims, num_indices, NO_PARAM);
}
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
// Reference code:
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
SizeT src_axis = 0;
SizeT dst_axis = 0;
for (int32_t i = 0; i < num_indices; i++) {
const NDIndex* index = &indices[i];
if (index->type == ND_INDEX_TYPE_SINGLE_ELEMENT) {
SizeT input = (SizeT) * ((int32_t*)index->data);
SizeT k = slice::resolve_index_in_length(src_ndarray->shape[src_axis], input);
if (k == -1) {
raise_exception(SizeT, EXN_INDEX_ERROR,
"index {0} is out of bounds for axis {1} "
"with size {2}",
input, src_axis, src_ndarray->shape[src_axis]);
}
dst_ndarray->data = static_cast<uint8_t*>(dst_ndarray->data) + k * src_ndarray->strides[src_axis];
src_axis++;
} else if (index->type == ND_INDEX_TYPE_SLICE) {
Slice<int32_t>* slice = (Slice<int32_t>*)index->data;
Range<int32_t> range = slice->indices_checked<SizeT>(src_ndarray->shape[src_axis]);
dst_ndarray->data = static_cast<uint8_t*>(dst_ndarray->data) + (SizeT)range.start * src_ndarray->strides[src_axis];
dst_ndarray->strides[dst_axis] = ((SizeT)range.step) * src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = (SizeT)range.len<SizeT>();
dst_axis++;
src_axis++;
} else if (index->type == ND_INDEX_TYPE_NEWAXIS) {
dst_ndarray->strides[dst_axis] = 0;
dst_ndarray->shape[dst_axis] = 1;
dst_axis++;
} else if (index->type == ND_INDEX_TYPE_ELLIPSIS) {
// The number of ':' entries this '...' implies.
SizeT ellipsis_size = src_ndarray->ndims - num_indexed;
for (SizeT j = 0; j < ellipsis_size; j++) {
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_axis++;
src_axis++;
}
} else {
__builtin_unreachable();
}
}
for (; dst_axis < dst_ndarray->ndims; dst_axis++, src_axis++) {
dst_ndarray->shape[dst_axis] = src_ndarray->shape[src_axis];
dst_ndarray->strides[dst_axis] = src_ndarray->strides[src_axis];
}
debug_assert_eq(SizeT, src_ndarray->ndims, src_axis);
debug_assert_eq(SizeT, dst_ndarray->ndims, dst_axis);
}
} // namespace indexing
} // namespace ndarray
} // namespace
extern "C" {
using namespace ndarray::indexing;
void __nac3_ndarray_index(int32_t num_indices,
NDIndex* indices,
NDArray<int32_t>* src_ndarray,
NDArray<int32_t>* dst_ndarray) {
index(num_indices, indices, src_ndarray, dst_ndarray);
}
void __nac3_ndarray_index64(int64_t num_indices,
NDIndex* indices,
NDArray<int64_t>* src_ndarray,
NDArray<int64_t>* dst_ndarray) {
index(num_indices, indices, src_ndarray, dst_ndarray);
}
}

View File

@ -0,0 +1,146 @@
#pragma once
#include "irrt/int_types.hpp"
#include "irrt/ndarray/def.hpp"
namespace {
/**
* @brief Helper struct to enumerate through an ndarray *efficiently*.
*
* Example usage (in pseudo-code):
* ```
* // Suppose my_ndarray has been initialized, with shape [2, 3] and dtype `double`
* NDIter nditer;
* nditer.initialize(my_ndarray);
* while (nditer.has_element()) {
* // This body is run 6 (= my_ndarray.size) times.
*
* // [0, 0] -> [0, 1] -> [0, 2] -> [1, 0] -> [1, 1] -> [1, 2] -> end
* print(nditer.indices);
*
* // 0 -> 1 -> 2 -> 3 -> 4 -> 5
* print(nditer.nth);
*
* // <1st element> -> <2nd element> -> ... -> <6th element> -> end
* print(*((double *) nditer.element))
*
* nditer.next(); // Go to next element.
* }
* ```
*
* Interesting cases:
* - If `my_ndarray.ndims` == 0, there is one iteration.
* - If `my_ndarray.shape` contains zeroes, there are no iterations.
*/
template<typename SizeT>
struct NDIter {
// Information about the ndarray being iterated over.
SizeT ndims;
SizeT* shape;
SizeT* strides;
/**
* @brief The current indices.
*
* Must be allocated by the caller.
*/
SizeT* indices;
/**
* @brief The nth (0-based) index of the current indices.
*
* Initially this is 0.
*/
SizeT nth;
/**
* @brief Pointer to the current element.
*
* Initially this points to first element of the ndarray.
*/
void* element;
/**
* @brief Cache for the product of shape.
*
* Could be 0 if `shape` has 0s in it.
*/
SizeT size;
void initialize(SizeT ndims, SizeT* shape, SizeT* strides, void* element, SizeT* indices) {
this->ndims = ndims;
this->shape = shape;
this->strides = strides;
this->indices = indices;
this->element = element;
// Compute size
this->size = 1;
for (SizeT i = 0; i < ndims; i++) {
this->size *= shape[i];
}
// `indices` starts on all 0s.
for (SizeT axis = 0; axis < ndims; axis++)
indices[axis] = 0;
nth = 0;
}
void initialize_by_ndarray(NDArray<SizeT>* ndarray, SizeT* indices) {
// NOTE: ndarray->data is pointing to the first element, and `NDIter`'s `element` should also point to the first
// element as well.
this->initialize(ndarray->ndims, ndarray->shape, ndarray->strides, ndarray->data, indices);
}
// Is the current iteration valid?
// If true, then `element`, `indices` and `nth` contain details about the current element.
bool has_element() { return nth < size; }
// Go to the next element.
void next() {
for (SizeT i = 0; i < ndims; i++) {
SizeT axis = ndims - i - 1;
indices[axis]++;
if (indices[axis] >= shape[axis]) {
indices[axis] = 0;
// TODO: There is something called backstrides to speedup iteration.
// See https://ajcr.net/stride-guide-part-1/, and
// https://docs.scipy.org/doc/numpy-1.13.0/reference/c-api.types-and-structures.html#c.PyArrayIterObject.PyArrayIterObject.backstrides.
element = static_cast<void*>(reinterpret_cast<uint8_t*>(element) - strides[axis] * (shape[axis] - 1));
} else {
element = static_cast<void*>(reinterpret_cast<uint8_t*>(element) + strides[axis]);
break;
}
}
nth++;
}
};
} // namespace
extern "C" {
void __nac3_nditer_initialize(NDIter<int32_t>* iter, NDArray<int32_t>* ndarray, int32_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
void __nac3_nditer_initialize64(NDIter<int64_t>* iter, NDArray<int64_t>* ndarray, int64_t* indices) {
iter->initialize_by_ndarray(ndarray, indices);
}
bool __nac3_nditer_has_element(NDIter<int32_t>* iter) {
return iter->has_element();
}
bool __nac3_nditer_has_element64(NDIter<int64_t>* iter) {
return iter->has_element();
}
void __nac3_nditer_next(NDIter<int32_t>* iter) {
iter->next();
}
void __nac3_nditer_next64(NDIter<int64_t>* iter) {
iter->next();
}
}

View File

@ -0,0 +1,47 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/int_types.hpp"
namespace {
namespace range {
template<typename T>
T len(T start, T stop, T step) {
// Reference:
// https://github.com/python/cpython/blob/9dbd12375561a393eaec4b21ee4ac568a407cdb0/Objects/rangeobject.c#L933
if (step > 0 && start < stop)
return 1 + (stop - 1 - start) / step;
else if (step < 0 && start > stop)
return 1 + (start - 1 - stop) / (-step);
else
return 0;
}
} // namespace range
/**
* @brief A Python range.
*/
template<typename T>
struct Range {
T start;
T stop;
T step;
/**
* @brief Calculate the `len()` of this range.
*/
template<typename SizeT>
T len() {
debug_assert(SizeT, step != 0);
return range::len(start, stop, step);
}
};
} // namespace
extern "C" {
using namespace range;
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
return len(start, end, step);
}
}

View File

@ -0,0 +1,156 @@
#pragma once
#include "irrt/debug.hpp"
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/math_util.hpp"
#include "irrt/range.hpp"
namespace {
namespace slice {
/**
* @brief Resolve a possibly negative index in a list of a known length.
*
* Returns -1 if the resolved index is out of the list's bounds.
*/
template<typename T>
T resolve_index_in_length(T length, T index) {
T resolved = index < 0 ? length + index : index;
if (0 <= resolved && resolved < length) {
return resolved;
} else {
return -1;
}
}
/**
* @brief Resolve a slice as a range.
*
* This is equivalent to `range(*slice(start, stop, step).indices(length))` in Python.
*/
template<typename T>
void indices(bool start_defined,
T start,
bool stop_defined,
T stop,
bool step_defined,
T step,
T length,
T* range_start,
T* range_stop,
T* range_step) {
// Reference: https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
*range_step = step_defined ? step : 1;
bool step_is_negative = *range_step < 0;
T lower, upper;
if (step_is_negative) {
lower = -1;
upper = length - 1;
} else {
lower = 0;
upper = length;
}
if (start_defined) {
*range_start = start < 0 ? max(lower, start + length) : min(upper, start);
} else {
*range_start = step_is_negative ? upper : lower;
}
if (stop_defined) {
*range_stop = stop < 0 ? max(lower, stop + length) : min(upper, stop);
} else {
*range_stop = step_is_negative ? lower : upper;
}
}
} // namespace slice
/**
* @brief A Python-like slice with **unresolved** indices.
*/
template<typename T>
struct Slice {
bool start_defined;
T start;
bool stop_defined;
T stop;
bool step_defined;
T step;
Slice() { this->reset(); }
void reset() {
this->start_defined = false;
this->stop_defined = false;
this->step_defined = false;
}
void set_start(T start) {
this->start_defined = true;
this->start = start;
}
void set_stop(T stop) {
this->stop_defined = true;
this->stop = stop;
}
void set_step(T step) {
this->step_defined = true;
this->step = step;
}
/**
* @brief Resolve this slice as a range.
*
* In Python, this would be `range(*slice(start, stop, step).indices(length))`.
*/
template<typename SizeT>
Range<T> indices(T length) {
// Reference:
// https://github.com/python/cpython/blob/main/Objects/sliceobject.c#L388
debug_assert(SizeT, length >= 0);
Range<T> result;
slice::indices(start_defined, start, stop_defined, stop, step_defined, step, length, &result.start,
&result.stop, &result.step);
return result;
}
/**
* @brief Like `.indices()` but with assertions.
*/
template<typename SizeT>
Range<T> indices_checked(T length) {
// TODO: Switch to `SizeT length`
if (length < 0) {
raise_exception(SizeT, EXN_VALUE_ERROR, "length should not be negative, got {0}", length, NO_PARAM,
NO_PARAM);
}
if (this->step_defined && this->step == 0) {
raise_exception(SizeT, EXN_VALUE_ERROR, "slice step cannot be zero", NO_PARAM, NO_PARAM, NO_PARAM);
}
return this->indices<SizeT>(length);
}
};
} // namespace
extern "C" {
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
}

View File

@ -0,0 +1,21 @@
[package]
name = "nac3core_derive"
version = "0.1.0"
edition = "2021"
[lib]
proc-macro = true
[[test]]
name = "structfields_tests"
path = "tests/structfields_test.rs"
[dev-dependencies]
nac3core = { path = ".." }
trybuild = { version = "1.0", features = ["diff"] }
[dependencies]
proc-macro2 = "1.0"
proc-macro-error = "1.0"
syn = "2.0"
quote = "1.0"

View File

@ -0,0 +1,320 @@
use proc_macro::TokenStream;
use proc_macro_error::{abort, proc_macro_error};
use quote::quote;
use syn::{
parse_macro_input, spanned::Spanned, Data, DataStruct, Expr, ExprField, ExprMethodCall,
ExprPath, GenericArgument, Ident, LitStr, Path, PathArguments, Type, TypePath,
};
/// Extracts all generic arguments of a [`Type`] into a [`Vec`].
///
/// Returns [`Some`] of a possibly-empty [`Vec`] if the path of `ty` matches with
/// `expected_ty_name`, otherwise returns [`None`].
fn extract_generic_args(expected_ty_name: &'static str, ty: &Type) -> Option<Vec<GenericArgument>> {
let Type::Path(TypePath { qself: None, path, .. }) = ty else {
return None;
};
let segments = &path.segments;
if segments.len() != 1 {
return None;
};
let segment = segments.iter().next().unwrap();
if segment.ident != expected_ty_name {
return None;
}
let PathArguments::AngleBracketed(path_args) = &segment.arguments else {
return Some(Vec::new());
};
let args = &path_args.args;
Some(args.iter().cloned().collect::<Vec<_>>())
}
/// Maps a `path` matching one of the `target_idents` into the `replacement` [`Ident`].
fn map_path_to_ident(path: &Path, target_idents: &[&str], replacement: &str) -> Option<Ident> {
path.require_ident()
.ok()
.filter(|ident| target_idents.iter().any(|target| ident == target))
.map(|ident| Ident::new(replacement, ident.span()))
}
/// Extracts the left-hand side of a dot-expression.
fn extract_dot_operand(expr: &Expr) -> Option<&Expr> {
match expr {
Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
| Expr::Field(ExprField { base: operand, .. }) => Some(operand),
_ => None,
}
}
/// Replaces the top-level receiver of a dot-expression with an [`Ident`], returning `Some(&mut expr)` if the
/// replacement is performed.
///
/// The top-level receiver is the left-most receiver expression, e.g. the top-level receiver of `a.b.c.foo()` is `a`.
fn replace_top_level_receiver(expr: &mut Expr, ident: Ident) -> Option<&mut Expr> {
if let Expr::MethodCall(ExprMethodCall { receiver: operand, .. })
| Expr::Field(ExprField { base: operand, .. }) = expr
{
return if extract_dot_operand(operand).is_some() {
if replace_top_level_receiver(operand, ident).is_some() {
Some(expr)
} else {
None
}
} else {
*operand = Box::new(Expr::Path(ExprPath {
attrs: Vec::default(),
qself: None,
path: ident.into(),
}));
Some(expr)
};
}
None
}
/// Iterates all operands to the left-hand side of the `.` of an [expression][`Expr`], i.e. the container operand of all
/// [`Expr::Field`] and the receiver operand of all [`Expr::MethodCall`].
///
/// The iterator will return the operand expressions in reverse order of appearance. For example, `a.b.c.func()` will
/// return `vec![c, b, a]`.
fn iter_dot_operands(expr: &Expr) -> impl Iterator<Item = &Expr> {
let mut o = extract_dot_operand(expr);
std::iter::from_fn(move || {
let this = o;
o = o.as_ref().and_then(|o| extract_dot_operand(o));
this
})
}
/// Normalizes a value expression for use when creating an instance of this structure, returning a
/// [`proc_macro2::TokenStream`] of tokens representing the normalized expression.
fn normalize_value_expr(expr: &Expr) -> proc_macro2::TokenStream {
match &expr {
Expr::Path(ExprPath { qself: None, path, .. }) => {
if let Some(ident) = map_path_to_ident(path, &["usize", "size_t"], "llvm_usize") {
quote! { #ident }
} else {
abort!(
path,
format!(
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
quote!(#expr).to_string(),
)
)
}
}
Expr::Call(_) => {
quote! { ctx.#expr }
}
Expr::MethodCall(_) => {
let base_receiver = iter_dot_operands(expr).last();
match base_receiver {
// `usize.{...}`, `size_t.{...}` -> Rewrite the identifiers to `llvm_usize`
Some(Expr::Path(ExprPath { qself: None, path, .. }))
if map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").is_some() =>
{
let ident =
map_path_to_ident(path, &["usize", "size_t"], "llvm_usize").unwrap();
let mut expr = expr.clone();
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
quote!(#expr)
}
// `ctx.{...}`, `context.{...}` -> Rewrite the identifiers to `ctx`
Some(Expr::Path(ExprPath { qself: None, path, .. }))
if map_path_to_ident(path, &["ctx", "context"], "ctx").is_some() =>
{
let ident = map_path_to_ident(path, &["ctx", "context"], "ctx").unwrap();
let mut expr = expr.clone();
let expr = replace_top_level_receiver(&mut expr, ident).unwrap();
quote!(#expr)
}
// No reserved identifier prefix -> Prepend `ctx.` to the entire expression
_ => quote! { ctx.#expr },
}
}
_ => {
abort!(
expr,
format!(
"Expected one of `size_t`, `usize`, or an implicit call expression in #[value_type(...)], found {}",
quote!(#expr).to_string(),
)
)
}
}
}
/// Derives an implementation of `codegen::types::structure::StructFields`.
///
/// The benefit of using `#[derive(StructFields)]` is that all index- or order-dependent logic required by
/// `impl StructFields` is automatically generated by this implementation, including the field index as required by
/// `StructField::new` and the fields as returned by `StructFields::to_vec`.
///
/// # Prerequisites
///
/// In order to derive from [`StructFields`], you must implement (or derive) [`Eq`] and [`Copy`] as required by
/// `StructFields`.
///
/// Moreover, `#[derive(StructFields)]` can only be used for `struct`s with named fields, and may only contain fields
/// with either `StructField` or [`PhantomData`] types.
///
/// # Attributes for [`StructFields`]
///
/// Each `StructField` field must be declared with the `#[value_type(...)]` attribute. The argument of `value_type`
/// accepts one of the following:
///
/// - An expression returning an instance of `inkwell::types::BasicType` (with or without the receiver `ctx`/`context`).
/// For example, `context.i8_type()`, `ctx.i8_type()`, and `i8_type()` all refer to `i8`.
/// - The reserved identifiers `usize` and `size_t` referring to an `inkwell::types::IntType` of the platform-dependent
/// integer size. `usize` and `size_t` can also be used as the receiver to other method calls, e.g.
/// `usize.array_type(3)`.
///
/// # Example
///
/// The following is an example of an LLVM slice implemented using `#[derive(StructFields)]`.
///
/// ```rust,ignore
/// use nac3core::{
/// codegen::types::structure::StructField,
/// inkwell::{
/// values::{IntValue, PointerValue},
/// AddressSpace,
/// },
/// };
/// use nac3core_derive::StructFields;
///
/// // All classes that implement StructFields must also implement Eq and Copy
/// #[derive(PartialEq, Eq, Clone, Copy, StructFields)]
/// pub struct SliceValue<'ctx> {
/// // Declares ptr have a value type of i8*
/// //
/// // Can also be written as `ctx.i8_type().ptr_type(...)` or `context.i8_type().ptr_type(...)`
/// #[value_type(i8_type().ptr_type(AddressSpace::default()))]
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
///
/// // Declares len have a value type of usize, depending on the target compilation platform
/// #[value_type(usize)]
/// len: StructField<'ctx, IntValue<'ctx>>,
/// }
/// ```
#[proc_macro_derive(StructFields, attributes(value_type))]
#[proc_macro_error]
pub fn derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as syn::DeriveInput);
let ident = &input.ident;
let Data::Struct(DataStruct { fields, .. }) = &input.data else {
abort!(input, "Only structs with named fields are supported");
};
if let Err(err_span) =
fields
.iter()
.try_for_each(|field| if field.ident.is_some() { Ok(()) } else { Err(field.span()) })
{
abort!(err_span, "Only structs with named fields are supported");
};
// Check if struct<'ctx>
if input.generics.params.len() != 1 {
abort!(input.generics, "Expected exactly 1 generic parameter")
}
let phantom_info = fields
.iter()
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_some())
.map(|field| field.ident.as_ref().unwrap())
.cloned()
.collect::<Vec<_>>();
let field_info = fields
.iter()
.filter(|field| extract_generic_args("PhantomData", &field.ty).is_none())
.map(|field| {
let ident = field.ident.as_ref().unwrap();
let ty = &field.ty;
let Some(_) = extract_generic_args("StructField", ty) else {
abort!(field, "Only StructField and PhantomData are allowed")
};
let attrs = &field.attrs;
let Some(value_type_attr) =
attrs.iter().find(|attr| attr.path().is_ident("value_type"))
else {
abort!(field, "Expected #[value_type(...)] attribute for field");
};
let Ok(value_type_expr) = value_type_attr.parse_args::<Expr>() else {
abort!(value_type_attr, "Expected expression in #[value_type(...)]");
};
let value_expr_toks = normalize_value_expr(&value_type_expr);
(ident.clone(), value_expr_toks)
})
.collect::<Vec<_>>();
// `<*>::new` impl of `StructField` and `PhantomData` for `StructFields::new`
let phantoms_create = phantom_info
.iter()
.map(|id| quote! { #id: ::std::marker::PhantomData })
.collect::<Vec<_>>();
let fields_create = field_info
.iter()
.map(|(id, ty)| {
let id_lit = LitStr::new(&id.to_string(), id.span());
quote! {
#id: ::nac3core::codegen::types::structure::StructField::create(
&mut counter,
#id_lit,
#ty,
)
}
})
.collect::<Vec<_>>();
// `.into()` impl of `StructField` for `StructFields::to_vec`
let fields_into =
field_info.iter().map(|(id, _)| quote! { self.#id.into() }).collect::<Vec<_>>();
let impl_block = quote! {
impl<'ctx> ::nac3core::codegen::types::structure::StructFields<'ctx> for #ident<'ctx> {
fn new(ctx: impl ::nac3core::inkwell::context::AsContextRef<'ctx>, llvm_usize: ::nac3core::inkwell::types::IntType<'ctx>) -> Self {
let ctx = unsafe { ::nac3core::inkwell::context::ContextRef::new(ctx.as_ctx_ref()) };
let mut counter = ::nac3core::codegen::types::structure::FieldIndexCounter::default();
#ident {
#(#fields_create),*
#(#phantoms_create),*
}
}
fn to_vec(&self) -> ::std::vec::Vec<(&'static str, ::nac3core::inkwell::types::BasicTypeEnum<'ctx>)> {
vec![
#(#fields_into),*
]
}
}
};
impl_block.into()
}

View File

@ -0,0 +1,9 @@
use nac3core_derive::StructFields;
use std::marker::PhantomData;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct EmptyValue<'ctx> {
_phantom: PhantomData<&'ctx ()>,
}
fn main() {}

View File

@ -0,0 +1,20 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDArrayValue<'ctx> {
#[value_type(usize)]
ndims: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
shape: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
data: StructField<'ctx, PointerValue<'ctx>>,
}
fn main() {}

View File

@ -0,0 +1,18 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -0,0 +1,18 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(context.i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -0,0 +1,18 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(ctx.i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -0,0 +1,18 @@
use nac3core::{
codegen::types::structure::StructField,
inkwell::{
values::{IntValue, PointerValue},
AddressSpace,
},
};
use nac3core_derive::StructFields;
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceValue<'ctx> {
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
ptr: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(size_t)]
len: StructField<'ctx, IntValue<'ctx>>,
}
fn main() {}

View File

@ -0,0 +1,10 @@
#[test]
fn test_parse_empty() {
let t = trybuild::TestCases::new();
t.pass("tests/structfields_empty.rs");
t.pass("tests/structfields_slice.rs");
t.pass("tests/structfields_slice_ctx.rs");
t.pass("tests/structfields_slice_context.rs");
t.pass("tests/structfields_slice_sizet.rs");
t.pass("tests/structfields_ndarray.rs");
}

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,20 @@
use std::collections::HashMap;
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use crate::{ use crate::{
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::DefinitionId, toplevel::DefinitionId,
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, typedef::{
into_var_map, FunSignature, FuncArg, Type, TypeEnum, TypeVar, TypeVarId, Unifier,
},
}, },
}; };
use nac3parser::ast::StrRef;
use std::collections::HashMap;
pub struct ConcreteTypeStore { pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>, store: Vec<ConcreteTypeEnum>,
} }
@ -22,6 +27,7 @@ pub struct ConcreteFuncArg {
pub name: StrRef, pub name: StrRef,
pub ty: ConcreteType, pub ty: ConcreteType,
pub default_value: Option<SymbolValue>, pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -43,14 +49,12 @@ pub enum ConcreteTypeEnum {
TPrimitive(Primitive), TPrimitive(Primitive),
TTuple { TTuple {
ty: Vec<ConcreteType>, ty: Vec<ConcreteType>,
}, is_vararg_ctx: bool,
TList {
ty: ConcreteType,
}, },
TObj { TObj {
obj_id: DefinitionId, obj_id: DefinitionId,
fields: HashMap<StrRef, (ConcreteType, bool)>, fields: HashMap<StrRef, (ConcreteType, bool)>,
params: HashMap<u32, ConcreteType>, params: IndexMap<TypeVarId, ConcreteType>,
}, },
TVirtual { TVirtual {
ty: ConcreteType, ty: ConcreteType,
@ -58,11 +62,10 @@ pub enum ConcreteTypeEnum {
TFunc { TFunc {
args: Vec<ConcreteFuncArg>, args: Vec<ConcreteFuncArg>,
ret: ConcreteType, ret: ConcreteType,
vars: HashMap<u32, ConcreteType>, vars: HashMap<TypeVarId, ConcreteType>,
}, },
TConstant { TLiteral {
value: SymbolValue, values: Vec<SymbolValue>,
ty: ConcreteType,
}, },
} }
@ -103,8 +106,16 @@ impl ConcreteTypeStore {
.iter() .iter()
.map(|arg| ConcreteFuncArg { .map(|arg| ConcreteFuncArg {
name: arg.name, name: arg.name,
ty: self.from_unifier_type(unifier, primitives, arg.ty, cache), ty: if arg.is_vararg {
let tuple_ty = unifier
.add_ty(TypeEnum::TTuple { ty: vec![arg.ty], is_vararg_ctx: true });
self.from_unifier_type(unifier, primitives, tuple_ty, cache)
} else {
self.from_unifier_type(unifier, primitives, arg.ty, cache)
},
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
}) })
.collect(), .collect(),
ret: self.from_unifier_type(unifier, primitives, signature.ret, cache), ret: self.from_unifier_type(unifier, primitives, signature.ret, cache),
@ -159,14 +170,12 @@ impl ConcreteTypeStore {
cache.insert(ty, None); cache.insert(ty, None);
let ty_enum = unifier.get_ty(ty); let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum { let result = match &*ty_enum {
TypeEnum::TTuple { ty } => ConcreteTypeEnum::TTuple { TypeEnum::TTuple { ty, is_vararg_ctx } => ConcreteTypeEnum::TTuple {
ty: ty ty: ty
.iter() .iter()
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache)) .map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
.collect(), .collect(),
}, is_vararg_ctx: *is_vararg_ctx,
TypeEnum::TList { ty } => ConcreteTypeEnum::TList {
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
}, },
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj { TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
obj_id: *obj_id, obj_id: *obj_id,
@ -202,10 +211,9 @@ impl ConcreteTypeStore {
TypeEnum::TFunc(signature) => { TypeEnum::TFunc(signature) => {
self.from_signature(unifier, primitives, signature, cache) self.from_signature(unifier, primitives, signature, cache)
} }
TypeEnum::TConstant { value, ty, .. } => ConcreteTypeEnum::TConstant { TypeEnum::TLiteral { values, .. } => {
value: value.clone(), ConcreteTypeEnum::TLiteral { values: values.clone() }
ty: self.from_unifier_type(unifier, primitives, *ty, cache), }
},
_ => unreachable!("{:?}", ty_enum.get_type_name()), _ => unreachable!("{:?}", ty_enum.get_type_name()),
}; };
let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() { let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() {
@ -231,7 +239,7 @@ impl ConcreteTypeStore {
return if let Some(ty) = ty { return if let Some(ty) = ty {
*ty *ty
} else { } else {
*ty = Some(unifier.get_dummy_var().0); *ty = Some(unifier.get_dummy_var().ty);
ty.unwrap() ty.unwrap()
}; };
} }
@ -253,15 +261,13 @@ impl ConcreteTypeStore {
*cache.get_mut(&cty).unwrap() = Some(ty); *cache.get_mut(&cty).unwrap() = Some(ty);
return ty; return ty;
} }
ConcreteTypeEnum::TTuple { ty } => TypeEnum::TTuple { ConcreteTypeEnum::TTuple { ty, is_vararg_ctx } => TypeEnum::TTuple {
ty: ty ty: ty
.iter() .iter()
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache)) .map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
.collect(), .collect(),
is_vararg_ctx: *is_vararg_ctx,
}, },
ConcreteTypeEnum::TList { ty } => {
TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
}
ConcreteTypeEnum::TVirtual { ty } => { ConcreteTypeEnum::TVirtual { ty } => {
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
} }
@ -273,10 +279,10 @@ impl ConcreteTypeStore {
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
}) })
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
params: params params: into_var_map(params.iter().map(|(&id, cty)| {
.iter() let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) TypeVar { id, ty }
.collect::<HashMap<_, _>>(), })),
}, },
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
args: args args: args
@ -285,18 +291,17 @@ impl ConcreteTypeStore {
name: arg.name, name: arg.name,
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache), ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: false,
}) })
.collect(), .collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache), ret: self.to_unifier_type(unifier, primitives, *ret, cache),
vars: vars vars: into_var_map(vars.iter().map(|(&id, cty)| {
.iter() let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) TypeVar { id, ty }
.collect::<HashMap<_, _>>(), })),
}), }),
ConcreteTypeEnum::TConstant { value, ty } => TypeEnum::TConstant { ConcreteTypeEnum::TLiteral { values, .. } => {
value: value.clone(), TypeEnum::TLiteral { values: values.clone(), loc: None }
ty: self.to_unifier_type(unifier, primitives, *ty, cache),
loc: None,
} }
}; };
let result = unifier.add_ty(result); let result = unifier.add_ty(result);

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,193 @@
use inkwell::{
attributes::{Attribute, AttributeLoc},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
};
use itertools::Either;
use super::CodeGenContext;
/// Macro to generate extern function
/// Both function return type and function parameter type are `FloatValue`
///
/// Arguments:
/// * `unary/binary`: Whether the extern function requires one (unary) or two (binary) operands
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
///
/// Optional Arguments:
/// * `$(,$attributes:literal)*)`: Attributes linked with the extern function.
/// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly".
/// These will be used unless other attributes are specified
/// * `$(,$args:ident)*`: Operands of the extern function
/// The data type of these operands will be set to `FloatValue`
///
macro_rules! generate_extern_fn {
("unary", $fn_name:ident, $extern_fn:literal) => {
generate_extern_fn!($fn_name, $extern_fn, arg, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
};
("unary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
generate_extern_fn!($fn_name, $extern_fn, arg $(,$attributes)*);
};
("binary", $fn_name:ident, $extern_fn:literal) => {
generate_extern_fn!($fn_name, $extern_fn, arg1, arg2, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
};
("binary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
generate_extern_fn!($fn_name, $extern_fn, arg1, arg2 $(,$attributes)*);
};
($fn_name:ident, $extern_fn:literal $(,$args:ident)* $(,$attributes:literal)*) => {
#[doc = concat!("Invokes the [`", stringify!($extern_fn), "`](https://en.cppreference.com/w/c/numeric/math/", stringify!($llvm_name), ") function." )]
pub fn $fn_name<'ctx>(
ctx: &CodeGenContext<'ctx, '_>
$(,$args: FloatValue<'ctx>)*,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = $extern_fn;
let llvm_f64 = ctx.ctx.f64_type();
$(debug_assert_eq!($args.get_type(), llvm_f64);)*
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[$($args.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in [$($attributes),*] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[$($args.into()),*], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
};
}
generate_extern_fn!("unary", call_tan, "tan");
generate_extern_fn!("unary", call_asin, "asin");
generate_extern_fn!("unary", call_acos, "acos");
generate_extern_fn!("unary", call_atan, "atan");
generate_extern_fn!("unary", call_sinh, "sinh");
generate_extern_fn!("unary", call_cosh, "cosh");
generate_extern_fn!("unary", call_tanh, "tanh");
generate_extern_fn!("unary", call_asinh, "asinh");
generate_extern_fn!("unary", call_acosh, "acosh");
generate_extern_fn!("unary", call_atanh, "atanh");
generate_extern_fn!("unary", call_expm1, "expm1");
generate_extern_fn!(
"unary",
call_cbrt,
"cbrt",
"mustprogress",
"nofree",
"nosync",
"nounwind",
"readonly",
"willreturn"
);
generate_extern_fn!("unary", call_erf, "erf", "nounwind");
generate_extern_fn!("unary", call_erfc, "erfc", "nounwind");
generate_extern_fn!("unary", call_j1, "j1", "nounwind");
generate_extern_fn!("binary", call_atan2, "atan2");
generate_extern_fn!("binary", call_hypot, "hypot", "nounwind");
generate_extern_fn!("binary", call_nextafter, "nextafter", "nounwind");
/// Invokes the [`ldexp`](https://en.cppreference.com/w/c/numeric/math/ldexp) function.
pub fn call_ldexp<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
exp: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "ldexp";
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
debug_assert_eq!(exp.get_type(), llvm_i32);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Macro to generate `np_linalg` and `sp_linalg` functions
/// The function takes as input `NDArray` and returns ()
///
/// Arguments:
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
/// * (2/3/4): Number of `NDArray` that function takes as input
///
/// Note:
/// The operands and resulting `NDArray` are both passed as input to the funcion
/// It is the responsibility of caller to ensure that output `NDArray` is properly allocated on stack
/// The function changes the content of the output `NDArray` in-place
macro_rules! generate_linalg_extern_fn {
($fn_name:ident, $extern_fn:literal, 2) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2);
};
($fn_name:ident, $extern_fn:literal, 3) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $extern_fn:literal, 4) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the linalg `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: BasicValueEnum<'ctx>)*,
name: Option<&str>,
){
const FN_NAME: &str = $extern_fn;
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
}
};
}
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);

View File

@ -1,16 +1,18 @@
use crate::{
codegen::{expr::*, stmt::*, bool_to_i1, bool_to_i8, CodeGenContext},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
use inkwell::{ use inkwell::{
context::Context, context::Context,
types::{BasicTypeEnum, IntType}, types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
}; };
use nac3parser::ast::{Expr, Stmt, StrRef}; use nac3parser::ast::{Expr, Stmt, StrRef};
use super::{bool_to_i1, bool_to_i8, expr::*, stmt::*, values::ArraySliceValue, CodeGenContext};
use crate::{
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
};
pub trait CodeGenerator { pub trait CodeGenerator {
/// Return the module name for the code generator. /// Return the module name for the code generator.
fn get_name(&self) -> &str; fn get_name(&self) -> &str;
@ -57,6 +59,7 @@ pub trait CodeGenerator {
/// - fun: Function signature, definition ID and the substitution key. /// - fun: Function signature, definition ID and the substitution key.
/// - params: Function parameters. Note that this does not include the object even if the /// - params: Function parameters. Note that this does not include the object even if the
/// function is a class method. /// function is a class method.
///
/// Note that this function should check if the function is generated in another thread (due to /// Note that this function should check if the function is generated in another thread (due to
/// possible race condition), see the default implementation for an example. /// possible race condition), see the default implementation for an example.
fn gen_func_instance<'ctx>( fn gen_func_instance<'ctx>(
@ -92,6 +95,18 @@ pub trait CodeGenerator {
gen_var(ctx, ty, name) gen_var(ctx, ty, name)
} }
/// Allocate memory for a variable and return a pointer pointing to it.
/// The default implementation places the allocations at the start of the function.
fn gen_array_var_alloc<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Result<ArraySliceValue<'ctx>, String> {
gen_array_var(ctx, ty, size, name)
}
/// Return a pointer pointing to the target of the expression. /// Return a pointer pointing to the target of the expression.
fn gen_store_target<'ctx>( fn gen_store_target<'ctx>(
&mut self, &mut self,
@ -111,11 +126,45 @@ pub trait CodeGenerator {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>, target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>, value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> ) -> Result<(), String>
where where
Self: Sized, Self: Sized,
{ {
gen_assign(self, ctx, target, value) gen_assign(self, ctx, target, value, value_ty)
}
/// Generate code for an assignment expression where LHS is a `"target_list"`.
///
/// See <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
fn gen_assign_target_list<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
targets: &Vec<Expr<Option<Type>>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_assign_target_list(self, ctx, targets, value, value_ty)
}
/// Generate code for an item assignment.
///
/// i.e., `target[key] = value`
fn gen_setitem<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
key: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_setitem(self, ctx, target, key, value, value_ty)
} }
/// Generate code for a while expression. /// Generate code for a while expression.
@ -131,8 +180,8 @@ pub trait CodeGenerator {
gen_while(self, ctx, stmt) gen_while(self, ctx, stmt)
} }
/// Generate code for a while expression. /// Generate code for a for expression.
/// Return true if the while loop must early return /// Return true if the for loop must early return
fn gen_for( fn gen_for(
&mut self, &mut self,
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
@ -198,7 +247,7 @@ pub trait CodeGenerator {
fn bool_to_i1<'ctx>( fn bool_to_i1<'ctx>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx> bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
bool_to_i1(&ctx.builder, bool_value) bool_to_i1(&ctx.builder, bool_value)
} }
@ -207,7 +256,7 @@ pub trait CodeGenerator {
fn bool_to_i8<'ctx>( fn bool_to_i8<'ctx>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx> bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
bool_to_i8(&ctx.builder, ctx.ctx, bool_value) bool_to_i8(&ctx.builder, ctx.ctx, bool_value)
} }
@ -227,7 +276,6 @@ impl DefaultCodeGenerator {
} }
impl CodeGenerator for DefaultCodeGenerator { impl CodeGenerator for DefaultCodeGenerator {
/// Returns the name for this [`CodeGenerator`]. /// Returns the name for this [`CodeGenerator`].
fn get_name(&self) -> &str { fn get_name(&self) -> &str {
&self.name &self.name

View File

@ -1,199 +0,0 @@
typedef _BitInt(8) int8_t;
typedef unsigned _BitInt(8) uint8_t;
typedef _BitInt(32) int32_t;
typedef unsigned _BitInt(32) uint32_t;
typedef _BitInt(64) int64_t;
typedef unsigned _BitInt(64) uint64_t;
# define MAX(a, b) (a > b ? a : b)
# define MIN(a, b) (a > b ? b : a)
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
#define DEF_INT_EXP(T) T __nac3_int_exp_##T( \
T base, \
T exp \
) { \
T res = (T)1; \
/* repeated squaring method */ \
do { \
if (exp & 1) res *= base; /* for n odd */ \
exp >>= 1; \
base *= base; \
} while (exp); \
return res; \
} \
DEF_INT_EXP(int32_t)
DEF_INT_EXP(int64_t)
DEF_INT_EXP(uint32_t)
DEF_INT_EXP(uint64_t)
int32_t __nac3_slice_index_bound(int32_t i, const int32_t len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
int32_t __nac3_range_slice_len(const int32_t start, const int32_t end, const int32_t step) {
int32_t diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
int32_t __nac3_list_slice_assign_var_size(
int32_t dest_start,
int32_t dest_end,
int32_t dest_step,
uint8_t *dest_arr,
int32_t dest_arr_len,
int32_t src_start,
int32_t src_end,
int32_t src_step,
uint8_t *src_arr,
int32_t src_arr_len,
const int32_t size
) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const int32_t src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const int32_t dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(
dest_arr + dest_start * size,
src_arr + src_start * size,
src_len * size
);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(
dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca =
(dest_arr == src_arr)
&& !(
MAX(dest_start, dest_end) < MIN(src_start, src_end)
|| MAX(src_start, src_end) < MIN(dest_start, dest_end)
);
if (need_alloca) {
uint8_t *tmp = __builtin_alloca(src_arr_len * size);
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
int32_t src_ind = src_start;
int32_t dest_ind = dest_start;
for (;
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
src_ind += src_step, dest_ind += dest_step
) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(
dest_arr + dest_ind * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}

View File

@ -0,0 +1,162 @@
use inkwell::{
types::BasicTypeEnum,
values::{BasicValueEnum, CallSiteValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
use super::calculate_len_for_slice_range;
use crate::codegen::{
macros::codegen_unreachable,
values::{ArrayLikeValue, ListValue},
CodeGenContext, CodeGenerator,
};
/// This function handles 'end' **inclusively**.
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
/// Negative index should be handled before entering this function
pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
dest_arr: ListValue<'ctx>,
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
src_arr: ListValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) {
let size_ty = generator.get_size_type(ctx.ctx);
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let int32 = ctx.ctx.i32_type();
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
let slice_assign_fun = {
let ty_vec = vec![
int32.into(), // dest start idx
int32.into(), // dest end idx
int32.into(), // dest step
elem_ptr_type.into(), // dest arr ptr
int32.into(), // dest arr len
int32.into(), // src start idx
int32.into(), // src end idx
int32.into(), // src step
elem_ptr_type.into(), // src arr ptr
int32.into(), // src arr len
int32.into(), // size
];
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
ctx.module.add_function(fun_symbol, fn_t, None)
})
};
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
let dest_arr_ptr =
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
let src_arr_ptr =
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
let src_len = src_arr.load_size(ctx, Some("src.len"));
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
// index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
// throw exception if not satisfied
let src_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let dest_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let src_slice_len =
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len =
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
let src_eq_dest = ctx
.builder
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
.unwrap();
let src_slt_dest = ctx
.builder
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
.unwrap();
let dest_step_eq_one = ctx
.builder
.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
)
.unwrap();
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
ctx.make_assert(
generator,
cond,
"0:ValueError",
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
ctx.current_loc,
);
let new_len = {
let args = vec![
dest_idx.0.into(), // dest start idx
dest_idx.1.into(), // dest end idx
dest_idx.2.into(), // dest step
dest_arr_ptr.into(), // dest arr ptr
dest_len.into(), // dest arr len
src_idx.0.into(), // src start idx
src_idx.1.into(), // src end idx
src_idx.2.into(), // src step
src_arr_ptr.into(), // src arr ptr
src_len.into(), // src arr len
{
let s = match ty {
BasicTypeEnum::FloatType(t) => t.size_of(),
BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => codegen_unreachable!(ctx),
};
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
}
.into(),
];
ctx.builder
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
};
// update length
let need_update =
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let update_bb = ctx.ctx.append_basic_block(current, "update");
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
ctx.builder.position_at_end(update_bb);
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
dest_arr.store_size(ctx, generator, new_len);
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
ctx.builder.position_at_end(cont_bb);
}

View File

@ -0,0 +1,152 @@
use inkwell::{
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
IntPredicate,
};
use itertools::Either;
use crate::codegen::{
macros::codegen_unreachable,
{CodeGenContext, CodeGenerator},
};
// repeated squaring method adapted from GNU Scientific Library:
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
base: IntValue<'ctx>,
exp: IntValue<'ctx>,
signed: bool,
) -> IntValue<'ctx> {
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) {
(32, 32, true) => "__nac3_int_exp_int32_t",
(64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t",
_ => codegen_unreachable!(ctx),
};
let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
ctx.module.add_function(symbol, fn_type, None)
});
// throw exception when exp < 0
let ge_zero = ctx
.builder
.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
)
.unwrap();
ctx.make_assert(
generator,
ge_zero,
"0:ValueError",
"integer power must be positive or zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isinf", fn_type, None)
});
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isinf")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isnan", fn_type, None)
});
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isnan")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gamma", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gamma")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gammaln", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gammaln")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_j0", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "j0")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,19 +1,29 @@
use crate::typecheck::typedef::Type;
use super::{CodeGenContext, CodeGenerator};
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
context::Context, context::Context,
memory_buffer::MemoryBuffer, memory_buffer::MemoryBuffer,
module::Module, module::Module,
types::BasicTypeEnum, values::{BasicValue, BasicValueEnum, IntValue},
values::{FloatValue, IntValue, PointerValue}, IntPredicate,
AddressSpace, IntPredicate,
}; };
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
use super::{CodeGenContext, CodeGenerator};
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
pub use list::*;
pub use math::*;
pub use range::*;
pub use slice::*;
mod list;
mod math;
pub mod ndarray;
mod range;
mod slice;
#[must_use] #[must_use]
pub fn load_irrt(ctx: &Context) -> Module { pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
let bitcode_buf = MemoryBuffer::create_from_memory_range( let bitcode_buf = MemoryBuffer::create_from_memory_range(
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")), include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
"irrt_bitcode_buffer", "irrt_bitcode_buffer",
@ -29,87 +39,47 @@ pub fn load_irrt(ctx: &Context) -> Module {
let function = irrt_mod.get_function(symbol).unwrap(); let function = irrt_mod.get_function(symbol).unwrap();
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0)); function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
} }
// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
let exn_id_type = ctx.i32_type();
let errors = &[
("EXN_INDEX_ERROR", "0:IndexError"),
("EXN_VALUE_ERROR", "0:ValueError"),
("EXN_ASSERTION_ERROR", "0:AssertionError"),
("EXN_TYPE_ERROR", "0:TypeError"),
];
for (irrt_name, symbol_name) in errors {
let exn_id = symbol_resolver.get_string_id(symbol_name);
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
let global = irrt_mod.get_global(irrt_name).unwrap_or_else(|| {
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
});
global.set_initializer(&exn_id);
}
irrt_mod irrt_mod
} }
// repeated squaring method adapted from GNU Scientific Library: /// Returns the name of a function which contains variants for 32-bit and 64-bit `size_t`.
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c ///
pub fn integer_power<'ctx>( /// - When [`TypeContext::size_type`] is 32-bits, the function name is `fn_name}`.
generator: &mut dyn CodeGenerator, /// - When [`TypeContext::size_type`] is 64-bits, the function name is `{fn_name}64`.
ctx: &mut CodeGenContext<'ctx, '_>, #[must_use]
base: IntValue<'ctx>, pub fn get_usize_dependent_function_name<G: CodeGenerator + ?Sized>(
exp: IntValue<'ctx>, generator: &G,
signed: bool, ctx: &CodeGenContext<'_, '_>,
) -> IntValue<'ctx> { name: &str,
let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) { ) -> String {
(32, 32, true) => "__nac3_int_exp_int32_t", let mut name = name.to_owned();
(64, 64, true) => "__nac3_int_exp_int64_t", match generator.get_size_type(ctx.ctx).get_bit_width() {
(32, 32, false) => "__nac3_int_exp_uint32_t", 32 => {}
(64, 64, false) => "__nac3_int_exp_uint64_t", 64 => name.push_str("64"),
_ => unreachable!(), bit_width => {
}; panic!("Unsupported int type bit width {bit_width}, must be either 32-bits or 64-bits")
let base_type = base.get_type(); }
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| { }
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false); name
ctx.module.add_function(symbol, fn_type, None)
});
// throw exception when exp < 0
let ge_zero = ctx.builder.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
);
ctx.make_assert(
generator,
ge_zero,
"0:ValueError",
"integer power must be positive or zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
.try_as_basic_value()
.unwrap_left()
.into_int_value()
}
pub fn calculate_len_for_slice_range<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
start: IntValue<'ctx>,
end: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
const SYMBOL: &str = "__nac3_range_slice_len";
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
// assert step != 0, throw exception if not
let not_zero = ctx.builder.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
);
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"step must not be zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
.try_as_basic_value()
.left()
.unwrap()
.into_int_value()
} }
/// NOTE: the output value of the end index of this function should be compared ***inclusively***, /// NOTE: the output value of the end index of this function should be compared ***inclusively***,
@ -158,13 +128,12 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
step: &Option<Box<Expr<Option<Type>>>>, step: &Option<Box<Expr<Option<Type>>>>,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G, generator: &mut G,
list: PointerValue<'ctx>, length: IntValue<'ctx>,
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> { ) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero(); let zero = int32.const_zero();
let one = int32.const_int(1, false); let one = int32.const_int(1, false);
let length = ctx.build_gep_and_load(list, &[zero, one], Some("length")).into_int_value(); let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32");
Ok(Some(match (start, end, step) { Ok(Some(match (start, end, step) {
(s, e, None) => ( (s, e, None) => (
if let Some(s) = s.as_ref() { if let Some(s) = s.as_ref() {
@ -184,7 +153,7 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
} else { } else {
length length
}; };
ctx.builder.build_int_sub(e, one, "final_end") ctx.builder.build_int_sub(e, one, "final_end").unwrap()
}, },
one, one,
), ),
@ -192,15 +161,18 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
let step = if let Some(v) = generator.gen_expr(ctx, step)? { let step = if let Some(v) = generator.gen_expr(ctx, step)? {
v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value() v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value()
} else { } else {
return Ok(None) return Ok(None);
}; };
// assert step != 0, throw exception if not // assert step != 0, throw exception if not
let not_zero = ctx.builder.build_int_compare( let not_zero = ctx
IntPredicate::NE, .builder
step, .build_int_compare(
step.get_type().const_zero(), IntPredicate::NE,
"range_step_ne", step,
); step.get_type().const_zero(),
"range_step_ne",
)
.unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
not_zero, not_zero,
@ -209,340 +181,69 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
[None, None, None], [None, None, None],
ctx.current_loc, ctx.current_loc,
); );
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1"); let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap();
let neg = ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg"); let neg = ctx
.builder
.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg")
.unwrap();
( (
match s { match s {
Some(s) => { Some(s) => {
let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else { let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else {
return Ok(None) return Ok(None);
}; };
ctx.builder ctx.builder
.build_select( .build_select(
ctx.builder.build_and( ctx.builder
ctx.builder.build_int_compare( .build_and(
IntPredicate::EQ, ctx.builder
s, .build_int_compare(
length, IntPredicate::EQ,
"s_eq_len", s,
), length,
neg, "s_eq_len",
"should_minus_one", )
), .unwrap(),
ctx.builder.build_int_sub(s, one, "s_min"), neg,
"should_minus_one",
)
.unwrap(),
ctx.builder.build_int_sub(s, one, "s_min").unwrap(),
s, s,
"final_start", "final_start",
) )
.into_int_value() .map(BasicValueEnum::into_int_value)
.unwrap()
} }
None => ctx.builder.build_select(neg, len_id, zero, "stt").into_int_value(), None => ctx
.builder
.build_select(neg, len_id, zero, "stt")
.map(BasicValueEnum::into_int_value)
.unwrap(),
}, },
match e { match e {
Some(e) => { Some(e) => {
let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else { let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else {
return Ok(None) return Ok(None);
}; };
ctx.builder ctx.builder
.build_select( .build_select(
neg, neg,
ctx.builder.build_int_add(e, one, "end_add_one"), ctx.builder.build_int_add(e, one, "end_add_one").unwrap(),
ctx.builder.build_int_sub(e, one, "end_sub_one"), ctx.builder.build_int_sub(e, one, "end_sub_one").unwrap(),
"final_end", "final_end",
) )
.into_int_value() .map(BasicValueEnum::into_int_value)
.unwrap()
} }
None => ctx.builder.build_select(neg, zero, len_id, "end").into_int_value(), None => ctx
.builder
.build_select(neg, zero, len_id, "end")
.map(BasicValueEnum::into_int_value)
.unwrap(),
}, },
step, step,
) )
} }
})) }))
} }
/// this function allows index out of range, since python
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
i: &Expr<Option<Type>>,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<IntValue<'ctx>>, String> {
const SYMBOL: &str = "__nac3_slice_index_bound";
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
} else {
return Ok(None)
};
Ok(Some(ctx
.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind")
.try_as_basic_value()
.left()
.unwrap()
.into_int_value()))
}
/// This function handles 'end' **inclusively**.
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
/// Negative index should be handled before entering this function
pub fn list_slice_assignment<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
dest_arr: PointerValue<'ctx>,
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
src_arr: PointerValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) {
let size_ty = generator.get_size_type(ctx.ctx);
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let int32 = ctx.ctx.i32_type();
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
let slice_assign_fun = {
let ty_vec = vec![
int32.into(), // dest start idx
int32.into(), // dest end idx
int32.into(), // dest step
elem_ptr_type.into(), // dest arr ptr
int32.into(), // dest arr len
int32.into(), // src start idx
int32.into(), // src end idx
int32.into(), // src step
elem_ptr_type.into(), // src arr ptr
int32.into(), // src arr len
int32.into(), // size
];
ctx.module.get_function(fun_symbol).unwrap_or_else(|| {
let fn_t = int32.fn_type(ty_vec.as_slice(), false);
ctx.module.add_function(fun_symbol, fn_t, None)
})
};
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let dest_arr_ptr = ctx.build_gep_and_load(dest_arr, &[zero, zero], Some("dest.addr"));
let dest_arr_ptr = ctx.builder.build_pointer_cast(
dest_arr_ptr.into_pointer_value(),
elem_ptr_type,
"dest_arr_ptr_cast",
);
let dest_len = ctx.build_gep_and_load(dest_arr, &[zero, one], Some("dest.len")).into_int_value();
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32");
let src_arr_ptr = ctx.build_gep_and_load(src_arr, &[zero, zero], Some("src.addr"));
let src_arr_ptr = ctx.builder.build_pointer_cast(
src_arr_ptr.into_pointer_value(),
elem_ptr_type,
"src_arr_ptr_cast",
);
let src_len = ctx.build_gep_and_load(src_arr, &[zero, one], Some("src.len")).into_int_value();
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32");
// index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
// throw exception if not satisfied
let src_end = ctx.builder
.build_select(
ctx.builder.build_int_compare(
IntPredicate::SLT,
src_idx.2,
zero,
"is_neg",
),
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one"),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one"),
"final_e",
)
.into_int_value();
let dest_end = ctx.builder
.build_select(
ctx.builder.build_int_compare(
IntPredicate::SLT,
dest_idx.2,
zero,
"is_neg",
),
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one"),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one"),
"final_e",
)
.into_int_value();
let src_slice_len =
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len =
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
let src_eq_dest = ctx.builder.build_int_compare(
IntPredicate::EQ,
src_slice_len,
dest_slice_len,
"slice_src_eq_dest",
);
let src_slt_dest = ctx.builder.build_int_compare(
IntPredicate::SLT,
src_slice_len,
dest_slice_len,
"slice_src_slt_dest",
);
let dest_step_eq_one = ctx.builder.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
);
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1");
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond");
ctx.make_assert(
generator,
cond,
"0:ValueError",
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
ctx.current_loc,
);
let new_len = {
let args = vec![
dest_idx.0.into(), // dest start idx
dest_idx.1.into(), // dest end idx
dest_idx.2.into(), // dest step
dest_arr_ptr.into(), // dest arr ptr
dest_len.into(), // dest arr len
src_idx.0.into(), // src start idx
src_idx.1.into(), // src end idx
src_idx.2.into(), // src step
src_arr_ptr.into(), // src arr ptr
src_len.into(), // src arr len
{
let s = match ty {
BasicTypeEnum::FloatType(t) => t.size_of(),
BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => unreachable!(),
};
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size")
}
.into(),
];
ctx.builder
.build_call(slice_assign_fun, args.as_slice(), "slice_assign")
.try_as_basic_value()
.unwrap_left()
.into_int_value()
};
// update length
let need_update =
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update");
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let update_bb = ctx.ctx.append_basic_block(current, "update");
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb);
ctx.builder.position_at_end(update_bb);
let dest_len_ptr = unsafe { ctx.builder.build_gep(dest_arr, &[zero, one], "dest_len_ptr") };
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len");
ctx.builder.build_store(dest_len_ptr, new_len);
ctx.builder.build_unconditional_branch(cont_bb);
ctx.builder.position_at_end(cont_bb);
}
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
pub fn call_isinf<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isinf", fn_type, None)
});
let ret = ctx.builder
.build_call(intrinsic_fn, &[v.into()], "isinf")
.try_as_basic_value()
.unwrap_left()
.into_int_value();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
pub fn call_isnan<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isnan", fn_type, None)
});
let ret = ctx.builder
.build_call(intrinsic_fn, &[v.into()], "isnan")
.try_as_basic_value()
.unwrap_left()
.into_int_value();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
pub fn call_gamma<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gamma", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gamma")
.try_as_basic_value()
.unwrap_left()
.into_float_value()
}
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
pub fn call_gammaln<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gammaln", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gammaln")
.try_as_basic_value()
.unwrap_left()
.into_float_value()
}
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_j0", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "j0")
.try_as_basic_value()
.unwrap_left()
.into_float_value()
}

View File

@ -0,0 +1,250 @@
use inkwell::{
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace,
};
use crate::codegen::{
expr::{create_and_call_function, infer_and_call_function},
irrt::get_usize_dependent_function_name,
types::ProxyType,
values::{ndarray::NDArrayValue, ProxyValue},
CodeGenContext, CodeGenerator,
};
pub fn call_nac3_ndarray_util_assert_shape_no_negative<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndims: IntValue<'ctx>,
shape: PointerValue<'ctx>,
) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let name = get_usize_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_util_assert_shape_no_negative",
);
create_and_call_function(
ctx,
&name,
Some(llvm_usize.into()),
&[(llvm_usize.into(), ndims.into()), (llvm_pusize.into(), shape.into())],
None,
None,
);
}
pub fn call_nac3_ndarray_util_assert_output_shape_same<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray_ndims: IntValue<'ctx>,
ndarray_shape: PointerValue<'ctx>,
output_ndims: IntValue<'ctx>,
output_shape: IntValue<'ctx>,
) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let name = get_usize_dependent_function_name(
generator,
ctx,
"__nac3_ndarray_util_assert_output_shape_same",
);
create_and_call_function(
ctx,
&name,
Some(llvm_usize.into()),
&[
(llvm_usize.into(), ndarray_ndims.into()),
(llvm_pusize.into(), ndarray_shape.into()),
(llvm_usize.into(), output_ndims.into()),
(llvm_pusize.into(), output_shape.into()),
],
None,
None,
);
}
pub fn call_nac3_ndarray_size<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_size");
create_and_call_function(
ctx,
&name,
Some(llvm_usize.into()),
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
Some("size"),
None,
)
.map(BasicValueEnum::into_int_value)
.unwrap()
}
pub fn call_nac3_ndarray_nbytes<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_nbytes");
create_and_call_function(
ctx,
&name,
Some(llvm_usize.into()),
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
Some("nbytes"),
None,
)
.map(BasicValueEnum::into_int_value)
.unwrap()
}
pub fn call_nac3_ndarray_len<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_len");
create_and_call_function(
ctx,
&name,
Some(llvm_usize.into()),
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
Some("len"),
None,
)
.map(BasicValueEnum::into_int_value)
.unwrap()
}
pub fn call_nac3_ndarray_is_c_contiguous<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
) -> IntValue<'ctx> {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_is_c_contiguous");
create_and_call_function(
ctx,
&name,
Some(llvm_i1.into()),
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
Some("is_c_contiguous"),
None,
)
.map(BasicValueEnum::into_int_value)
.unwrap()
}
pub fn call_nac3_ndarray_get_nth_pelement<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
index: IntValue<'ctx>,
) -> PointerValue<'ctx> {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_ndarray = ndarray.get_type().as_base_type();
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_nth_pelement");
create_and_call_function(
ctx,
&name,
Some(llvm_pi8.into()),
&[(llvm_ndarray.into(), ndarray.as_base_value().into()), (llvm_usize.into(), index.into())],
Some("pelement"),
None,
)
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
pub fn call_nac3_ndarray_get_pelement_by_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: PointerValue<'ctx>,
) -> PointerValue<'ctx> {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let llvm_ndarray = ndarray.get_type().as_base_type();
let name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_get_pelement_by_indices");
create_and_call_function(
ctx,
&name,
Some(llvm_pi8.into()),
&[
(llvm_ndarray.into(), ndarray.as_base_value().into()),
(llvm_pusize.into(), indices.into()),
],
Some("pelement"),
None,
)
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
pub fn call_nac3_ndarray_set_strides_by_shape<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
) {
let llvm_ndarray = ndarray.get_type().as_base_type();
let name =
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_set_strides_by_shape");
create_and_call_function(
ctx,
&name,
None,
&[(llvm_ndarray.into(), ndarray.as_base_value().into())],
None,
None,
);
}
pub fn call_nac3_ndarray_copy_data<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayValue<'ctx>,
dst_ndarray: NDArrayValue<'ctx>,
) {
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_copy_data");
infer_and_call_function(
ctx,
&name,
None,
&[src_ndarray.as_base_value().into(), dst_ndarray.as_base_value().into()],
None,
None,
);
}

View File

@ -0,0 +1,29 @@
use crate::codegen::{
expr::infer_and_call_function,
irrt::get_usize_dependent_function_name,
values::{ndarray::NDArrayValue, ArrayLikeValue, ArraySliceValue, ProxyValue},
CodeGenContext, CodeGenerator,
};
pub fn call_nac3_ndarray_index<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
indices: ArraySliceValue<'ctx>,
src_ndarray: NDArrayValue<'ctx>,
dst_ndarray: NDArrayValue<'ctx>,
) {
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_index");
infer_and_call_function(
ctx,
&name,
None,
&[
indices.size(ctx, generator).into(),
indices.base_ptr(ctx, generator).into(),
src_ndarray.as_base_value().into(),
dst_ndarray.as_base_value().into(),
],
None,
None,
);
}

View File

@ -0,0 +1,70 @@
use inkwell::{
values::{BasicValueEnum, IntValue},
AddressSpace,
};
use crate::codegen::{
expr::{create_and_call_function, infer_and_call_function},
irrt::get_usize_dependent_function_name,
types::ProxyType,
values::{
ndarray::{NDArrayValue, NDIterValue},
ArrayLikeValue, ArraySliceValue, ProxyValue,
},
CodeGenContext, CodeGenerator,
};
pub fn call_nac3_nditer_initialize<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
iter: NDIterValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
indices: ArraySliceValue<'ctx>,
) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_initialize");
create_and_call_function(
ctx,
&name,
None,
&[
(iter.get_type().as_base_type().into(), iter.as_base_value().into()),
(ndarray.get_type().as_base_type().into(), ndarray.as_base_value().into()),
(llvm_pusize.into(), indices.base_ptr(ctx, generator).into()),
],
None,
None,
);
}
pub fn call_nac3_nditer_has_element<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
iter: NDIterValue<'ctx>,
) -> IntValue<'ctx> {
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_has_element");
infer_and_call_function(
ctx,
&name,
Some(ctx.ctx.bool_type().into()),
&[iter.as_base_value().into()],
None,
None,
)
.map(BasicValueEnum::into_int_value)
.unwrap()
}
pub fn call_nac3_nditer_next<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
iter: NDIterValue<'ctx>,
) {
let name = get_usize_dependent_function_name(generator, ctx, "__nac3_nditer_next");
infer_and_call_function(ctx, &name, None, &[iter.as_base_value().into()], None, None);
}

View File

@ -0,0 +1,391 @@
use inkwell::{
types::IntType,
values::{BasicValueEnum, CallSiteValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
use crate::codegen::{
llvm_intrinsics,
macros::codegen_unreachable,
stmt::gen_for_callback_incrementing,
values::{
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
CodeGenContext, CodeGenerator,
};
pub use basic::*;
pub use indexing::*;
pub use iter::*;
mod basic;
mod indexing;
mod iter;
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension
/// respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
let ndarray_calc_size_fn =
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.shape();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
ctx.builder
.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Indices,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.shape();
let index = ctx
.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
index
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
rhs.shape().get_typed_unchecked(ctx, generator, &idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let rhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
let lhs_eq_rhs = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap();
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.shape().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.shape().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.shape().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}

View File

@ -0,0 +1,42 @@
use inkwell::{
values::{BasicValueEnum, CallSiteValue, IntValue},
IntPredicate,
};
use itertools::Either;
use crate::codegen::{CodeGenContext, CodeGenerator};
pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
start: IntValue<'ctx>,
end: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
const SYMBOL: &str = "__nac3_range_slice_len";
let len_func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
// assert step != 0, throw exception if not
let not_zero = ctx
.builder
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
.unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"step must not be zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -0,0 +1,39 @@
use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue};
use itertools::Either;
use nac3parser::ast::Expr;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
typecheck::typedef::Type,
};
/// this function allows index out of range, since python
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
i: &Expr<Option<Type>>,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
length: IntValue<'ctx>,
) -> Result<Option<IntValue<'ctx>>, String> {
const SYMBOL: &str = "__nac3_slice_index_bound";
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type();
let fn_t = i32_t.fn_type(&[i32_t.into(), i32_t.into()], false);
ctx.module.add_function(SYMBOL, fn_t, None)
});
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
} else {
return Ok(None);
};
Ok(Some(
ctx.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap(),
))
}

View File

@ -0,0 +1,410 @@
use inkwell::{
context::Context,
intrinsics::Intrinsic,
types::{AnyTypeEnum::IntType, FloatType},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue},
AddressSpace,
};
use itertools::Either;
use super::CodeGenContext;
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
/// functions.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
// Standard LLVM floating-point types
if ft == ctx.f16_type() {
return "f16";
}
if ft == ctx.f32_type() {
return "f32";
}
if ft == ctx.f64_type() {
return "f64";
}
if ft == ctx.f128_type() {
return "f128";
}
// Non-standard floating-point types
if ft == ctx.x86_f80_type() {
return "f80";
}
if ft == ctx.ppc_f128_type() {
return "ppcf128";
}
unreachable!()
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_start";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_end";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
/// intrinsic.
pub fn call_stacksave<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
name: Option<&str>,
) -> PointerValue<'ctx> {
const FN_NAME: &str = "llvm.stacksave";
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_pointer_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the
/// [`llvm.stackrestore`](https://llvm.org/docs/LangRef.html#llvm-stackrestore-intrinsic) intrinsic.
///
/// - `ptr`: The pointer storing the address to restore the stack to.
pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.stackrestore";
/*
SEE https://github.com/TheDan64/inkwell/issues/496
We want `llvm.stackrestore`, but the following would generate `llvm.stackrestore.p0i8`.
```ignore
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_p0i8.into()]))
.unwrap();
```
Temp workaround by manually declaring the intrinsic with the correct function name instead.
*/
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[ptr.into()], "").unwrap();
}
/// Invokes the [`llvm.memcpy`](https://llvm.org/docs/LangRef.html#llvm-memcpy-intrinsic) intrinsic.
///
/// * `dest` - The pointer to the destination. Must be a pointer to an integer type.
/// * `src` - The pointer to the source. Must be a pointer to an integer type.
/// * `len` - The number of bytes to copy.
/// * `is_volatile` - Whether the `memcpy` operation should be `volatile`.
pub fn call_memcpy<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
const FN_NAME: &str = "llvm.memcpy";
debug_assert!(dest.get_type().get_element_type().is_int_type());
debug_assert!(src.get_type().get_element_type().is_int_type());
debug_assert_eq!(
dest.get_type().get_element_type().into_int_type().get_bit_width(),
src.get_type().get_element_type().into_int_type().get_bit_width(),
);
debug_assert!(matches!(len.get_type().get_bit_width(), 32 | 64));
debug_assert_eq!(is_volatile.get_type().get_bit_width(), 1);
let llvm_dest_t = dest.get_type();
let llvm_src_t = src.get_type();
let llvm_len_t = len.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| {
intrinsic.get_declaration(
&ctx.module,
&[llvm_dest_t.into(), llvm_src_t.into(), llvm_len_t.into()],
)
})
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[dest.into(), src.into(), len.into(), is_volatile.into()], "")
.unwrap();
}
/// Invokes the `llvm.memcpy` intrinsic.
///
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`.
pub fn call_memcpy_generic<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let dest_elem_t = dest.get_type().get_element_type();
let src_elem_t = src.get_type().get_element_type();
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
dest
} else {
ctx.builder
.build_bit_cast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
src
} else {
ctx.builder
.build_bit_cast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
call_memcpy(ctx, dest, src, len, is_volatile);
}
/// Invokes the `llvm.memcpy` intrinsic.
///
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`.
/// Moreover, `len` now refers to the number of elements to copy (rather than number of bytes to
/// copy).
pub fn call_memcpy_generic_array<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_sizeof_expr_t = llvm_i8.size_of().get_type();
let dest_elem_t = dest.get_type().get_element_type();
let src_elem_t = src.get_type().get_element_type();
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
dest
} else {
ctx.builder
.build_bit_cast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
src
} else {
ctx.builder
.build_bit_cast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
let len = ctx.builder.build_int_z_extend_or_bit_cast(len, llvm_sizeof_expr_t, "").unwrap();
let len = ctx.builder.build_int_mul(len, src_elem_t.size_of().unwrap(), "").unwrap();
call_memcpy(ctx, dest, src, len, is_volatile);
}
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
///
/// Arguments:
/// * `$ctx:ident`: Reference to the current Code Generation Context
/// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>)
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
/// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type).
/// Use `BasicValueEnum::into_int_value` for Integer return type and
/// `BasicValueEnum::into_float_value` for Float return type
/// * `$llvm_ty:ident`: Type of first operand
/// * `,($val:ident)*`: Comma separated list of operands
macro_rules! generate_llvm_intrinsic_fn_body {
($ctx:ident, $name:ident, $llvm_name:literal, $map_fn:expr, $llvm_ty:ident $(,$val:ident)*) => {{
const FN_NAME: &str = concat!("llvm.", $llvm_name);
let intrinsic_fn = Intrinsic::find(FN_NAME).and_then(|intrinsic| intrinsic.get_declaration(&$ctx.module, &[$llvm_ty.into()])).unwrap();
$ctx.builder.build_call(intrinsic_fn, &[$($val.into()),*], $name.unwrap_or_default()).map(CallSiteValue::try_as_basic_value).map(|v| v.map_left($map_fn)).map(Either::unwrap_left).unwrap()
}};
}
/// Macro to generate the llvm intrinsic function using [`generate_llvm_intrinsic_fn_body`].
///
/// Arguments:
/// * `float/int`: Indicates the return and argument type of the function
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function.
/// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil"
/// * `$val:ident`: The operand for unary operations
/// * `$val1:ident`, `$val2:ident`: The operands for binary operations
macro_rules! generate_llvm_intrinsic_fn {
("float", $fn_name:ident, $llvm_name:literal, $val:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_ty = $val.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_float_value, llvm_ty, $val)
}
};
("float", $fn_name:ident, $llvm_name:literal, $val1:ident, $val2:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val1: FloatValue<'ctx>,
$val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!($val1.get_type(), $val2.get_type());
let llvm_ty = $val1.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_float_value, llvm_ty, $val1, $val2)
}
};
("int", $fn_name:ident, $llvm_name:literal, $val1:ident, $val2:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val1: IntValue<'ctx>,
$val2: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!($val1.get_type().get_bit_width(), $val2.get_type().get_bit_width());
let llvm_ty = $val1.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_int_value, llvm_ty, $val1, $val2)
}
};
}
/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic.
///
/// * `src` - The value for which the absolute value is to be returned.
/// * `is_int_min_poison` - Whether `poison` is to be returned if `src` is `INT_MIN`.
pub fn call_int_abs<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: IntValue<'ctx>,
is_int_min_poison: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(is_int_min_poison.get_type().get_bit_width(), 1);
debug_assert!(is_int_min_poison.is_const());
let src_type = src.get_type();
generate_llvm_intrinsic_fn_body!(
ctx,
name,
"abs",
BasicValueEnum::into_int_value,
src_type,
src,
is_int_min_poison
)
}
generate_llvm_intrinsic_fn!("int", call_int_smax, "smax", a, b);
generate_llvm_intrinsic_fn!("int", call_int_smin, "smin", a, b);
generate_llvm_intrinsic_fn!("int", call_int_umax, "umax", a, b);
generate_llvm_intrinsic_fn!("int", call_int_umin, "umin", a, b);
generate_llvm_intrinsic_fn!("int", call_expect, "expect", val, expected_val);
generate_llvm_intrinsic_fn!("float", call_float_sqrt, "sqrt", val);
generate_llvm_intrinsic_fn!("float", call_float_sin, "sin", val);
generate_llvm_intrinsic_fn!("float", call_float_cos, "cos", val);
generate_llvm_intrinsic_fn!("float", call_float_pow, "pow", val, power);
generate_llvm_intrinsic_fn!("float", call_float_exp, "exp", val);
generate_llvm_intrinsic_fn!("float", call_float_exp2, "exp2", val);
generate_llvm_intrinsic_fn!("float", call_float_log, "log", val);
generate_llvm_intrinsic_fn!("float", call_float_log10, "log10", val);
generate_llvm_intrinsic_fn!("float", call_float_log2, "log2", val);
generate_llvm_intrinsic_fn!("float", call_float_fabs, "fabs", src);
generate_llvm_intrinsic_fn!("float", call_float_minnum, "minnum", val, power);
generate_llvm_intrinsic_fn!("float", call_float_maxnum, "maxnum", val, power);
generate_llvm_intrinsic_fn!("float", call_float_copysign, "copysign", mag, sgn);
generate_llvm_intrinsic_fn!("float", call_float_floor, "floor", val);
generate_llvm_intrinsic_fn!("float", call_float_ceil, "ceil", val);
generate_llvm_intrinsic_fn!("float", call_float_round, "round", val);
generate_llvm_intrinsic_fn!("float", call_float_rint, "rint", val);
/// Invokes the [`llvm.powi`](https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic) intrinsic.
pub fn call_float_powi<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
power: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.powi";
let llvm_val_t = val.get_type();
let llvm_power_t = power.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| {
intrinsic.get_declaration(&ctx.module, &[llvm_val_t.into(), llvm_power_t.into()])
})
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the [`llvm.ctpop`](https://llvm.org/docs/LangRef.html#llvm-ctpop-intrinsic) intrinsic.
pub fn call_int_ctpop<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
const FN_NAME: &str = "llvm.ctpop";
let llvm_src_t = src.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_src_t.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[src.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,50 +1,79 @@
use crate::{ use std::{
symbol_resolver::{StaticValue, SymbolResolver}, collections::{HashMap, HashSet},
toplevel::{TopLevelContext, TopLevelDef}, sync::{
typecheck::{ atomic::{AtomicBool, Ordering},
type_inferencer::{CodeLocation, PrimitiveStore}, Arc,
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
}, },
thread,
}; };
use crossbeam::channel::{unbounded, Receiver, Sender}; use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{ use inkwell::{
AddressSpace,
IntPredicate,
OptimizationLevel,
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock, basic_block::BasicBlock,
builder::Builder, builder::Builder,
context::Context, context::Context,
debug_info::{
AsDIScope, DICompileUnit, DIFlagsConstants, DIScope, DISubprogram, DebugInfoBuilder,
},
module::Module, module::Module,
passes::PassBuilderOptions, passes::PassBuilderOptions,
targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple}, targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple},
types::{AnyType, BasicType, BasicTypeEnum}, types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue}, values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
debug_info::{ AddressSpace, IntPredicate, OptimizationLevel,
DebugInfoBuilder, DICompileUnit, DISubprogram, AsDIScope, DIFlagsConstants, DIScope
},
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3parser::ast::{Stmt, StrRef, Location};
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::thread;
use nac3parser::ast::{Location, Stmt, StrRef};
use crate::{
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{
helper::{extract_ndims, PrimDef},
numpy::unpack_ndarray_var_tys,
TopLevelContext, TopLevelDef,
},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
},
};
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
pub use generator::{CodeGenerator, DefaultCodeGenerator};
use types::{ndarray::NDArrayType, ListType, ProxyType, RangeType};
pub mod builtin_fns;
pub mod concrete_type; pub mod concrete_type;
pub mod expr; pub mod expr;
pub mod extern_fns;
mod generator; mod generator;
pub mod irrt; pub mod irrt;
pub mod llvm_intrinsics;
pub mod numpy;
pub mod stmt; pub mod stmt;
pub mod types;
pub mod values;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore}; mod macros {
pub use generator::{CodeGenerator, DefaultCodeGenerator}; /// Codegen-variant of [`std::unreachable`] which accepts an instance of [`CodeGenContext`] as
/// its first argument to provide Python source information to indicate the codegen location
/// causing the assertion.
macro_rules! codegen_unreachable {
($ctx:expr $(,)?) => {
std::unreachable!("unreachable code while processing {}", &$ctx.current_loc)
};
($ctx:expr, $($arg:tt)*) => {
std::unreachable!("unreachable code while processing {}: {}", &$ctx.current_loc, std::format!("{}", std::format_args!($($arg)+)))
};
}
pub(crate) use codegen_unreachable;
}
#[derive(Default)] #[derive(Default)]
pub struct StaticValueStore { pub struct StaticValueStore {
@ -64,6 +93,16 @@ pub struct CodeGenLLVMOptions {
pub target: CodeGenTargetMachineOptions, pub target: CodeGenTargetMachineOptions,
} }
impl CodeGenLLVMOptions {
/// Creates a [`TargetMachine`] using the target options specified by this struct.
///
/// See [`Target::create_target_machine`].
#[must_use]
pub fn create_target_machine(&self) -> Option<TargetMachine> {
self.target.create_target_machine(self.opt_level)
}
}
/// Additional options for code generation for the target machine. /// Additional options for code generation for the target machine.
#[derive(Clone, Debug, Eq, PartialEq)] #[derive(Clone, Debug, Eq, PartialEq)]
pub struct CodeGenTargetMachineOptions { pub struct CodeGenTargetMachineOptions {
@ -80,7 +119,6 @@ pub struct CodeGenTargetMachineOptions {
} }
impl CodeGenTargetMachineOptions { impl CodeGenTargetMachineOptions {
/// Creates an instance of [`CodeGenTargetMachineOptions`] using the triple of the host machine. /// Creates an instance of [`CodeGenTargetMachineOptions`] using the triple of the host machine.
/// Other options are set to defaults. /// Other options are set to defaults.
#[must_use] #[must_use]
@ -109,13 +147,11 @@ impl CodeGenTargetMachineOptions {
/// ///
/// See [`Target::create_target_machine`]. /// See [`Target::create_target_machine`].
#[must_use] #[must_use]
pub fn create_target_machine( pub fn create_target_machine(&self, level: OptimizationLevel) -> Option<TargetMachine> {
&self,
level: OptimizationLevel,
) -> Option<TargetMachine> {
let triple = TargetTriple::create(self.triple.as_str()); let triple = TargetTriple::create(self.triple.as_str());
let target = Target::from_triple(&triple) let target = Target::from_triple(&triple).unwrap_or_else(|_| {
.unwrap_or_else(|_| panic!("could not create target from target triple {}", self.triple)); panic!("could not create target from target triple {}", self.triple)
});
target.create_target_machine( target.create_target_machine(
&triple, &triple,
@ -123,7 +159,7 @@ impl CodeGenTargetMachineOptions {
self.features.as_str(), self.features.as_str(),
level, level,
self.reloc_mode, self.reloc_mode,
self.code_model self.code_model,
) )
} }
} }
@ -134,24 +170,23 @@ pub struct CodeGenContext<'ctx, 'a> {
/// The [Builder] instance for creating LLVM IR statements. /// The [Builder] instance for creating LLVM IR statements.
pub builder: Builder<'ctx>, pub builder: Builder<'ctx>,
/// The [DebugInfoBuilder], [compilation unit information][DICompileUnit], and /// The [`DebugInfoBuilder`], [compilation unit information][DICompileUnit], and
/// [scope information][DIScope] of this context. /// [scope information][DIScope] of this context.
pub debug_info: (DebugInfoBuilder<'ctx>, DICompileUnit<'ctx>, DIScope<'ctx>), pub debug_info: (DebugInfoBuilder<'ctx>, DICompileUnit<'ctx>, DIScope<'ctx>),
/// The module for which [this context][CodeGenContext] is generating into. /// The module for which [this context][CodeGenContext] is generating into.
pub module: Module<'ctx>, pub module: Module<'ctx>,
/// The [TopLevelContext] associated with [this context][CodeGenContext]. /// The [`TopLevelContext`] associated with [this context][CodeGenContext].
pub top_level: &'a TopLevelContext, pub top_level: &'a TopLevelContext,
pub unifier: Unifier, pub unifier: Unifier,
pub resolver: Arc<dyn SymbolResolver + Send + Sync>, pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
pub static_value_store: Arc<Mutex<StaticValueStore>>, pub static_value_store: Arc<Mutex<StaticValueStore>>,
/// A [HashMap] containing the mapping between the names of variables currently in-scope and /// A [`HashMap`] containing the mapping between the names of variables currently in-scope and
/// its value information. /// its value information.
pub var_assignment: HashMap<StrRef, VarValue<'ctx>>, pub var_assignment: HashMap<StrRef, VarValue<'ctx>>,
///
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>, pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
pub primitives: PrimitiveStore, pub primitives: PrimitiveStore,
pub calls: Arc<HashMap<CodeLocation, CallId>>, pub calls: Arc<HashMap<CodeLocation, CallId>>,
@ -160,24 +195,24 @@ pub struct CodeGenContext<'ctx, 'a> {
/// Cache for constant strings. /// Cache for constant strings.
pub const_strings: HashMap<String, BasicValueEnum<'ctx>>, pub const_strings: HashMap<String, BasicValueEnum<'ctx>>,
/// [BasicBlock] containing all `alloca` statements for the current function. /// [`BasicBlock`] containing all `alloca` statements for the current function.
pub init_bb: BasicBlock<'ctx>, pub init_bb: BasicBlock<'ctx>,
pub exception_val: Option<PointerValue<'ctx>>, pub exception_val: Option<PointerValue<'ctx>>,
/// The header and exit basic blocks of a loop in this context. See /// The header and exit basic blocks of a loop in this context. See
/// https://llvm.org/docs/LoopTerminology.html for explanation of these terminology. /// <https://llvm.org/docs/LoopTerminology.html> for explanation of these terminology.
pub loop_target: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>, pub loop_target: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
/// The target [BasicBlock] to jump to when performing stack unwind. /// The target [`BasicBlock`] to jump to when performing stack unwind.
pub unwind_target: Option<BasicBlock<'ctx>>, pub unwind_target: Option<BasicBlock<'ctx>>,
/// The target [BasicBlock] to jump to before returning from the function. /// The target [`BasicBlock`] to jump to before returning from the function.
/// ///
/// If this field is [None] when generating a return from a function, `ret` with no argument can /// If this field is [None] when generating a return from a function, `ret` with no argument can
/// be emitted. /// be emitted.
pub return_target: Option<BasicBlock<'ctx>>, pub return_target: Option<BasicBlock<'ctx>>,
/// The [PointerValue] containing the return value of the function. /// The [`PointerValue`] containing the return value of the function.
pub return_buffer: Option<PointerValue<'ctx>>, pub return_buffer: Option<PointerValue<'ctx>>,
// outer catch clauses // outer catch clauses
@ -186,7 +221,7 @@ pub struct CodeGenContext<'ctx, 'a> {
/// Whether `sret` is needed for the first parameter of the function. /// Whether `sret` is needed for the first parameter of the function.
/// ///
/// See [need_sret]. /// See [`need_sret`].
pub need_sret: bool, pub need_sret: bool,
/// The current source location. /// The current source location.
@ -194,7 +229,6 @@ pub struct CodeGenContext<'ctx, 'a> {
} }
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
/// Whether the [current basic block][Builder::get_insert_block] referenced by `builder` /// Whether the [current basic block][Builder::get_insert_block] referenced by `builder`
/// contains a [terminator statement][BasicBlock::get_terminator]. /// contains a [terminator statement][BasicBlock::get_terminator].
pub fn is_terminated(&self) -> bool { pub fn is_terminated(&self) -> bool {
@ -236,11 +270,10 @@ pub struct WorkerRegistry {
static_value_store: Arc<Mutex<StaticValueStore>>, static_value_store: Arc<Mutex<StaticValueStore>>,
/// LLVM-related options for code generation. /// LLVM-related options for code generation.
llvm_options: CodeGenLLVMOptions, pub llvm_options: CodeGenLLVMOptions,
} }
impl WorkerRegistry { impl WorkerRegistry {
/// Creates workers for this registry. /// Creates workers for this registry.
#[must_use] #[must_use]
pub fn create_workers<G: CodeGenerator + Send + 'static>( pub fn create_workers<G: CodeGenerator + Send + 'static>(
@ -275,9 +308,15 @@ impl WorkerRegistry {
let registry = registry.clone(); let registry = registry.clone();
let registry2 = registry.clone(); let registry2 = registry.clone();
let f = f.clone(); let f = f.clone();
let handle = thread::spawn(move || {
registry.worker_thread(generator.as_mut(), &f); let worker_thread_name =
}); format!("codegen-worker-{worker_id}", worker_id = generator.get_name());
let handle = thread::Builder::new()
.name(worker_thread_name)
.spawn(move || {
registry.worker_thread(generator.as_mut(), &f);
})
.unwrap();
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
if let Err(e) = handle.join() { if let Err(e) = handle.join() {
if let Some(e) = e.downcast_ref::<&'static str>() { if let Some(e) = e.downcast_ref::<&'static str>() {
@ -334,6 +373,10 @@ impl WorkerRegistry {
let mut builder = context.create_builder(); let mut builder = context.create_builder();
let mut module = context.create_module(generator.get_name()); let mut module = context.create_module(generator.get_name());
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
module.add_basic_value_flag( module.add_basic_value_flag(
"Debug Info Version", "Debug Info Version",
inkwell::module::FlagBehavior::Warning, inkwell::module::FlagBehavior::Warning,
@ -357,12 +400,20 @@ impl WorkerRegistry {
errors.insert(e); errors.insert(e);
// create a new empty module just to continue codegen and collect errors // create a new empty module just to continue codegen and collect errors
module = context.create_module(&format!("{}_recover", generator.get_name())); module = context.create_module(&format!("{}_recover", generator.get_name()));
let target_machine = self.llvm_options.create_target_machine().unwrap();
module.set_data_layout(&target_machine.get_target_data().get_data_layout());
module.set_triple(&target_machine.get_triple());
} }
} }
*self.task_count.lock() -= 1; *self.task_count.lock() -= 1;
self.wait_condvar.notify_all(); self.wait_condvar.notify_all();
} }
assert!(errors.is_empty(), "Codegen error: {}", errors.into_iter().sorted().join("\n----------\n")); assert!(
errors.is_empty(),
"Codegen error: {}",
errors.into_iter().sorted().join("\n----------\n")
);
let result = module.verify(); let result = module.verify();
if let Err(err) = result { if let Err(err) = result {
@ -375,13 +426,20 @@ impl WorkerRegistry {
.llvm_options .llvm_options
.target .target
.create_target_machine(self.llvm_options.opt_level) .create_target_machine(self.llvm_options.opt_level)
.unwrap_or_else(|| panic!("could not create target machine from properties {:?}", self.llvm_options.target)); .unwrap_or_else(|| {
panic!(
"could not create target machine from properties {:?}",
self.llvm_options.target
)
});
let passes = format!("default<O{}>", self.llvm_options.opt_level as u32); let passes = format!("default<O{}>", self.llvm_options.opt_level as u32);
let result = module.run_passes(passes.as_str(), &target_machine, pass_options); let result = module.run_passes(passes.as_str(), &target_machine, pass_options);
if let Err(err) = result { if let Err(err) = result {
panic!("Failed to run optimization for module `{}`: {}", panic!(
module.get_name().to_str().unwrap(), "Failed to run optimization for module `{}`: {}",
err.to_string()); module.get_name().to_str().unwrap(),
err.to_string()
);
} }
f.run(&module); f.run(&module);
@ -407,14 +465,14 @@ pub struct CodeGenTask {
/// ///
/// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable /// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable
/// would be represented by an `i8`. /// would be represented by an `i8`.
fn get_llvm_type<'ctx>( #[allow(clippy::too_many_arguments)]
fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>, module: &Module<'ctx>,
generator: &mut dyn CodeGenerator, generator: &G,
unifier: &mut Unifier, unifier: &mut Unifier,
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
primitives: &PrimitiveStore,
ty: Type, ty: Type,
) -> BasicTypeEnum<'ctx> { ) -> BasicTypeEnum<'ctx> {
use TypeEnum::*; use TypeEnum::*;
@ -424,29 +482,52 @@ fn get_llvm_type<'ctx>(
let ty_enum = unifier.get_ty(ty); let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum { let result = match &*ty_enum {
TObj { obj_id, fields, .. } => { TObj { obj_id, fields, .. } => {
// check to avoid treating primitives other than Option as classes // check to avoid treating non-class primitives as classes
if obj_id.0 <= 10 { if PrimDef::contains_id(*obj_id) {
match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref()) return match &*unifier.get_ty_immutable(ty) {
{ TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => {
( get_llvm_type(
TObj { obj_id, params, .. },
TObj { obj_id: opt_id, .. },
) if *obj_id == *opt_id => {
return get_llvm_type(
ctx, ctx,
module, module,
generator, generator,
unifier, unifier,
top_level, top_level,
type_cache, type_cache,
primitives,
*params.iter().next().unwrap().1, *params.iter().next().unwrap().1,
) )
.ptr_type(AddressSpace::default()) .ptr_type(AddressSpace::default())
.into(); .into()
} }
_ => unreachable!("must be option type"),
} TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let element_type = get_llvm_type(
ctx,
module,
generator,
unifier,
top_level,
type_cache,
*params.iter().next().unwrap().1,
);
ListType::new(generator, ctx, element_type).as_base_type().into()
}
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, ndims) = unpack_ndarray_var_tys(unifier, ty);
let ndims = extract_ndims(unifier, ndims);
let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype,
);
NDArrayType::new(generator, ctx, element_type, Some(ndims)).as_base_type().into()
}
_ => unreachable!(
"LLVM type for primitive {} is missing",
unifier.stringify(ty)
),
};
} }
// a struct with fields in the order of declaration // a struct with fields in the order of declaration
let top_level_defs = top_level.definitions.read(); let top_level_defs = top_level.definitions.read();
@ -462,7 +543,7 @@ fn get_llvm_type<'ctx>(
let struct_type = ctx.opaque_struct_type(&name); let struct_type = ctx.opaque_struct_type(&name);
type_cache.insert( type_cache.insert(
unifier.get_representative(ty), unifier.get_representative(ty),
struct_type.ptr_type(AddressSpace::default()).into() struct_type.ptr_type(AddressSpace::default()).into(),
); );
let fields = fields_list let fields = fields_list
.iter() .iter()
@ -474,7 +555,6 @@ fn get_llvm_type<'ctx>(
unifier, unifier,
top_level, top_level,
type_cache, type_cache,
primitives,
fields[&f.0].0, fields[&f.0].0,
) )
}) })
@ -482,31 +562,20 @@ fn get_llvm_type<'ctx>(
struct_type.set_body(&fields, false); struct_type.set_body(&fields, false);
struct_type.ptr_type(AddressSpace::default()).into() struct_type.ptr_type(AddressSpace::default()).into()
}; };
return ty return ty;
} }
TTuple { ty } => { TTuple { ty, is_vararg_ctx } => {
// a struct with fields in the order present in the tuple // a struct with fields in the order present in the tuple
assert!(!is_vararg_ctx, "Tuples in vararg context must be instantiated with the correct number of arguments before calling get_llvm_type");
let fields = ty let fields = ty
.iter() .iter()
.map(|ty| { .map(|ty| {
get_llvm_type( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
)
}) })
.collect_vec(); .collect_vec();
ctx.struct_type(&fields, false).into() ctx.struct_type(&fields, false).into()
} }
TList { ty } => {
// a struct with an integer and a pointer to an array
let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
);
let fields = [
element_type.ptr_type(AddressSpace::default()).into(),
generator.get_size_type(ctx).into(),
];
ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into()
}
TVirtual { .. } => unimplemented!(), TVirtual { .. } => unimplemented!(),
_ => unreachable!("{}", ty_enum.get_type_name()), _ => unreachable!("{}", ty_enum.get_type_name()),
}; };
@ -524,10 +593,11 @@ fn get_llvm_type<'ctx>(
/// ABI representation is that the in-memory representation must be at least byte-sized and must /// ABI representation is that the in-memory representation must be at least byte-sized and must
/// be byte-aligned for the variable to be addressable in memory, whereas there is no such /// be byte-aligned for the variable to be addressable in memory, whereas there is no such
/// restriction for ABI representations. /// restriction for ABI representations.
fn get_llvm_abi_type<'ctx>( #[allow(clippy::too_many_arguments)]
fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>, module: &Module<'ctx>,
generator: &mut dyn CodeGenerator, generator: &G,
unifier: &mut Unifier, unifier: &mut Unifier,
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
@ -536,10 +606,10 @@ fn get_llvm_abi_type<'ctx>(
) -> BasicTypeEnum<'ctx> { ) -> BasicTypeEnum<'ctx> {
// If the type is used in the definition of a function, return `i1` instead of `i8` for ABI // If the type is used in the definition of a function, return `i1` instead of `i8` for ABI
// consistency. // consistency.
return if unifier.unioned(ty, primitives.bool) { if unifier.unioned(ty, primitives.bool) {
ctx.bool_type().into() ctx.bool_type().into()
} else { } else {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, primitives, ty) get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
} }
} }
@ -556,23 +626,62 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
match ty { match ty {
BasicTypeEnum::IntType(_) | BasicTypeEnum::PointerType(_) => false, BasicTypeEnum::IntType(_) | BasicTypeEnum::PointerType(_) => false,
BasicTypeEnum::FloatType(_) if maybe_large => false, BasicTypeEnum::FloatType(_) if maybe_large => false,
BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 => BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 => {
ty.get_field_types().iter().any(|ty| need_sret_impl(*ty, false)), ty.get_field_types().iter().any(|ty| need_sret_impl(*ty, false))
}
_ => true, _ => true,
} }
} }
need_sret_impl(ty, true) need_sret_impl(ty, true)
} }
/// Returns the [`BasicTypeEnum`] representing a `va_list` struct for variadic arguments.
fn get_llvm_valist_type<'ctx>(ctx: &'ctx Context, triple: &TargetTriple) -> BasicTypeEnum<'ctx> {
let triple = TargetMachine::normalize_triple(triple);
let triple = triple.as_str().to_str().unwrap();
let arch = triple.split('-').next().unwrap();
let llvm_pi8 = ctx.i8_type().ptr_type(AddressSpace::default());
// Referenced from parseArch() in llvm/lib/Support/Triple.cpp
match arch {
"i386" | "i486" | "i586" | "i686" | "riscv32" => {
ctx.i8_type().ptr_type(AddressSpace::default()).into()
}
"amd64" | "x86_64" | "x86_64h" => {
let llvm_i32 = ctx.i32_type();
let va_list_tag = ctx.opaque_struct_type("struct.__va_list_tag");
va_list_tag.set_body(
&[llvm_i32.into(), llvm_i32.into(), llvm_pi8.into(), llvm_pi8.into()],
false,
);
va_list_tag.into()
}
"armv7" => {
let va_list = ctx.opaque_struct_type("struct.__va_list");
va_list.set_body(&[llvm_pi8.into()], false);
va_list.into()
}
triple => {
todo!("Unsupported platform for varargs: {triple}")
}
}
}
/// Implementation for generating LLVM IR for a function. /// Implementation for generating LLVM IR for a function.
pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>> ( pub fn gen_func_impl<
'ctx,
G: CodeGenerator,
F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>,
>(
context: &'ctx Context, context: &'ctx Context,
generator: &mut G, generator: &mut G,
registry: &WorkerRegistry, registry: &WorkerRegistry,
builder: Builder<'ctx>, builder: Builder<'ctx>,
module: Module<'ctx>, module: Module<'ctx>,
task: CodeGenTask, task: CodeGenTask,
codegen_function: F codegen_function: F,
) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> { ) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> {
let top_level_ctx = registry.top_level_ctx.clone(); let top_level_ctx = registry.top_level_ctx.clone();
let static_value_store = registry.static_value_store.clone(); let static_value_store = registry.static_value_store.clone();
@ -580,6 +689,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let (unifier, primitives) = &top_level_ctx.unifiers.read()[task.unifier_index]; let (unifier, primitives) = &top_level_ctx.unifiers.read()[task.unifier_index];
(Unifier::from_shared_unifier(unifier), *primitives) (Unifier::from_shared_unifier(unifier), *primitives)
}; };
unifier.put_primitive_store(&primitives);
unifier.top_level = Some(top_level_ctx.clone()); unifier.top_level = Some(top_level_ctx.clone());
let mut cache = HashMap::new(); let mut cache = HashMap::new();
@ -613,6 +723,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
str: unifier.get_representative(primitives.str), str: unifier.get_representative(primitives.str),
exception: unifier.get_representative(primitives.exception), exception: unifier.get_representative(primitives.exception),
option: unifier.get_representative(primitives.option), option: unifier.get_representative(primitives.option),
..primitives
}; };
let mut type_cache: HashMap<_, _> = [ let mut type_cache: HashMap<_, _> = [
@ -634,10 +745,10 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
str_type.set_body(&fields, false); str_type.set_body(&fields, false);
str_type.into() str_type.into()
} }
Some(t) => t.as_basic_type_enum() Some(t) => t.as_basic_type_enum(),
} }
}), }),
(primitives.range, context.i32_type().array_type(3).ptr_type(AddressSpace::default()).into()), (primitives.range, RangeType::new(context).as_base_type().into()),
(primitives.exception, { (primitives.exception, {
let name = "Exception"; let name = "Exception";
if let Some(t) = module.get_struct_type(name) { if let Some(t) = module.get_struct_type(name) {
@ -651,7 +762,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
exception.set_body(&fields, false); exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::default()).as_basic_type_enum() exception.ptr_type(AddressSpace::default()).as_basic_type_enum()
} }
}) }),
] ]
.iter() .iter()
.copied() .copied()
@ -659,8 +770,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
// NOTE: special handling of option cannot use this type cache since it contains type var, // NOTE: special handling of option cannot use this type cache since it contains type var,
// handled inside get_llvm_type instead // handled inside get_llvm_type instead
let ConcreteTypeEnum::TFunc { args, ret, .. } = let ConcreteTypeEnum::TFunc { args, ret, .. } = task.store.get(task.signature) else {
task.store.get(task.signature) else {
unreachable!() unreachable!()
}; };
@ -670,6 +780,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
name: arg.name, name: arg.name,
ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache), ty: task.store.to_unifier_type(&mut unifier, &primitives, arg.ty, &mut cache),
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
}) })
.collect_vec(), .collect_vec(),
task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache), task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache),
@ -677,13 +788,25 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let ret_type = if unifier.unioned(ret, primitives.none) { let ret_type = if unifier.unioned(ret, primitives.none) {
None None
} else { } else {
Some(get_llvm_abi_type(context, &module, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, &primitives, ret)) Some(get_llvm_abi_type(
context,
&module,
generator,
&mut unifier,
top_level_ctx.as_ref(),
&mut type_cache,
&primitives,
ret,
))
}; };
let has_sret = ret_type.map_or(false, |ty| need_sret(ty)); let has_sret = ret_type.map_or(false, |ty| need_sret(ty));
let mut params = args let mut params = args
.iter() .iter()
.filter(|arg| !arg.is_vararg)
.map(|arg| { .map(|arg| {
debug_assert!(!arg.is_vararg);
get_llvm_abi_type( get_llvm_abi_type(
context, context,
&module, &module,
@ -702,9 +825,12 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into()); params.insert(0, ret_type.unwrap().ptr_type(AddressSpace::default()).into());
} }
debug_assert!(matches!(args.iter().filter(|arg| arg.is_vararg).count(), 0..=1));
let vararg_arg = args.iter().find(|arg| arg.is_vararg);
let fn_type = match ret_type { let fn_type = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, false), Some(ret_type) if !has_sret => ret_type.fn_type(&params, vararg_arg.is_some()),
_ => context.void_type().fn_type(&params, false) _ => context.void_type().fn_type(&params, vararg_arg.is_some()),
}; };
let symbol = &task.symbol_name; let symbol = &task.symbol_name;
@ -719,18 +845,23 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
fn_val.set_personality_function(personality); fn_val.set_personality_function(personality);
} }
if has_sret { if has_sret {
fn_val.add_attribute(AttributeLoc::Param(0), fn_val.add_attribute(
context.create_type_attribute(Attribute::get_named_enum_kind_id("sret"), AttributeLoc::Param(0),
ret_type.unwrap().as_any_type_enum())); context.create_type_attribute(
Attribute::get_named_enum_kind_id("sret"),
ret_type.unwrap().as_any_type_enum(),
),
);
} }
let init_bb = context.append_basic_block(fn_val, "init"); let init_bb = context.append_basic_block(fn_val, "init");
builder.position_at_end(init_bb); builder.position_at_end(init_bb);
let body_bb = context.append_basic_block(fn_val, "body"); let body_bb = context.append_basic_block(fn_val, "body");
// Store non-vararg argument values into local variables
let mut var_assignment = HashMap::new(); let mut var_assignment = HashMap::new();
let offset = u32::from(has_sret); let offset = u32::from(has_sret);
for (n, arg) in args.iter().enumerate() { for (n, arg) in args.iter().enumerate().filter(|(_, arg)| !arg.is_vararg) {
let param = fn_val.get_nth_param((n as u32) + offset).unwrap(); let param = fn_val.get_nth_param((n as u32) + offset).unwrap();
let local_type = get_llvm_type( let local_type = get_llvm_type(
context, context,
@ -739,13 +870,10 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
&mut unifier, &mut unifier,
top_level_ctx.as_ref(), top_level_ctx.as_ref(),
&mut type_cache, &mut type_cache,
&primitives,
arg.ty, arg.ty,
); );
let alloca = builder.build_alloca( let alloca =
local_type, builder.build_alloca(local_type, &format!("{}.addr", &arg.name.to_string())).unwrap();
&format!("{}.addr", &arg.name.to_string()),
);
// Remap boolean parameters into i8 // Remap boolean parameters into i8
let param = if local_type.is_int_type() && param.is_int_value() { let param = if local_type.is_int_type() && param.is_int_value() {
@ -756,19 +884,22 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
bool_to_i8(&builder, context, param_val) bool_to_i8(&builder, context, param_val)
} else { } else {
param_val param_val
}.into() }
.into()
} else { } else {
param param
}; };
builder.build_store(alloca, param); builder.build_store(alloca, param).unwrap();
var_assignment.insert(arg.name, (alloca, None, 0)); var_assignment.insert(arg.name, (alloca, None, 0));
} }
// TODO: Save vararg parameters as list
let return_buffer = if has_sret { let return_buffer = if has_sret {
Some(fn_val.get_nth_param(0).unwrap().into_pointer_value()) Some(fn_val.get_nth_param(0).unwrap().into_pointer_value())
} else { } else {
fn_type.get_return_type().map(|v| builder.build_alloca(v, "$ret")) fn_type.get_return_type().map(|v| builder.build_alloca(v, "$ret").unwrap())
}; };
let static_values = { let static_values = {
@ -780,7 +911,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
*static_val = Some(v); *static_val = Some(v);
} }
builder.build_unconditional_branch(body_bb); builder.build_unconditional_branch(body_bb).unwrap();
builder.position_at_end(body_bb); builder.position_at_end(body_bb);
let (dibuilder, compile_unit) = module.create_debug_info_builder( let (dibuilder, compile_unit) = module.create_debug_info_builder(
@ -789,11 +920,8 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
/* filename */ /* filename */
&task &task
.body .body
.get(0) .first()
.map_or_else( .map_or_else(|| "<nac3_internal>".to_string(), |f| f.location.file.0.to_string()),
|| "<nac3_internal>".to_string(),
|f| f.location.file.0.to_string(),
),
/* directory */ "", /* directory */ "",
/* producer */ "NAC3", /* producer */ "NAC3",
/* is_optimized */ registry.llvm_options.opt_level != OptimizationLevel::None, /* is_optimized */ registry.llvm_options.opt_level != OptimizationLevel::None,
@ -819,7 +947,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
inkwell::debug_info::DIFlags::PUBLIC, inkwell::debug_info::DIFlags::PUBLIC,
); );
let (row, col) = let (row, col) =
task.body.get(0).map_or_else(|| (0, 0), |b| (b.location.row, b.location.column)); task.body.first().map_or_else(|| (0, 0), |b| (b.location.row, b.location.column));
let func_scope: DISubprogram<'_> = dibuilder.create_function( let func_scope: DISubprogram<'_> = dibuilder.create_function(
/* scope */ compile_unit.as_debug_info_scope(), /* scope */ compile_unit.as_debug_info_scope(),
/* func name */ symbol, /* func name */ symbol,
@ -866,15 +994,15 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
row as u32, row as u32,
col as u32, col as u32,
func_scope.as_debug_info_scope(), func_scope.as_debug_info_scope(),
None None,
); );
code_gen_context.builder.set_current_debug_location(loc); code_gen_context.builder.set_current_debug_location(loc);
let result = codegen_function(generator, &mut code_gen_context); let result = codegen_function(generator, &mut code_gen_context);
// after static analysis, only void functions can have no return at the end. // after static analysis, only void functions can have no return at the end.
if !code_gen_context.is_terminated() { if !code_gen_context.is_terminated() {
code_gen_context.builder.build_return(None); code_gen_context.builder.build_return(None).unwrap();
} }
code_gen_context.builder.unset_current_debug_location(); code_gen_context.builder.unset_current_debug_location();
@ -916,12 +1044,14 @@ fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntV
if bool_value.get_type().get_bit_width() == 1 { if bool_value.get_type().get_bit_width() == 1 {
bool_value bool_value
} else { } else {
builder.build_int_compare( builder
IntPredicate::NE, .build_int_compare(
bool_value, IntPredicate::NE,
bool_value.get_type().const_zero(), bool_value,
"tobool" bool_value.get_type().const_zero(),
) "tobool",
)
.unwrap()
} }
} }
@ -929,21 +1059,23 @@ fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntV
fn bool_to_i8<'ctx>( fn bool_to_i8<'ctx>(
builder: &Builder<'ctx>, builder: &Builder<'ctx>,
ctx: &'ctx Context, ctx: &'ctx Context,
bool_value: IntValue<'ctx> bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let value_bits = bool_value.get_type().get_bit_width(); let value_bits = bool_value.get_type().get_bit_width();
match value_bits { match value_bits {
8 => bool_value, 8 => bool_value,
1 => builder.build_int_z_extend(bool_value, ctx.i8_type(), "frombool"), 1 => builder.build_int_z_extend(bool_value, ctx.i8_type(), "frombool").unwrap(),
_ => bool_to_i8( _ => bool_to_i8(
builder, builder,
ctx, ctx,
builder.build_int_compare( builder
IntPredicate::NE, .build_int_compare(
bool_value, IntPredicate::NE,
bool_value.get_type().const_zero(), bool_value,
"" bool_value.get_type().const_zero(),
) "",
)
.unwrap(),
), ),
} }
} }
@ -969,9 +1101,129 @@ fn gen_in_range_check<'ctx>(
stop: IntValue<'ctx>, stop: IntValue<'ctx>,
step: IntValue<'ctx>, step: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), ""); let sign = ctx
let lo = ctx.builder.build_select(sign, value, stop, "").into_int_value(); .builder
let hi = ctx.builder.build_select(sign, stop, value, "").into_int_value(); .build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "")
.unwrap();
let lo = ctx
.builder
.build_select(sign, value, stop, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let hi = ctx
.builder
.build_select(sign, stop, value, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
}
/// Returns the internal name for the `va_count` argument, used to indicate the number of arguments
/// passed to the variadic function.
fn get_va_count_arg_name(arg_name: StrRef) -> StrRef {
format!("__{}_va_count", &arg_name).into()
}
/// Returns the alignment of the type.
///
/// This is necessary as `get_alignment` is not implemented as part of [`BasicType`].
pub fn get_type_alignment<'ctx>(ty: impl Into<BasicTypeEnum<'ctx>>) -> IntValue<'ctx> {
match ty.into() {
BasicTypeEnum::ArrayType(ty) => ty.get_alignment(),
BasicTypeEnum::FloatType(ty) => ty.get_alignment(),
BasicTypeEnum::IntType(ty) => ty.get_alignment(),
BasicTypeEnum::PointerType(ty) => ty.get_alignment(),
BasicTypeEnum::StructType(ty) => ty.get_alignment(),
BasicTypeEnum::VectorType(ty) => ty.get_alignment(),
}
}
/// Inserts an `alloca` instruction with allocation `size` given in bytes and the alignment of the
/// given type.
///
/// The returned [`PointerValue`] will have a type of `i8*`, a size of at least `size`, and will be
/// aligned with the alignment of `align_ty`.
pub fn type_aligned_alloca<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
align_ty: impl Into<BasicTypeEnum<'ctx>>,
size: IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
/// Round `val` up to its modulo `power_of_two`.
fn round_up<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: IntValue<'ctx>,
power_of_two: IntValue<'ctx>,
) -> IntValue<'ctx> {
debug_assert_eq!(
val.get_type().get_bit_width(),
power_of_two.get_type().get_bit_width(),
"`val` ({}) and `power_of_two` ({}) must be the same type",
val.get_type(),
power_of_two.get_type(),
);
let llvm_val_t = val.get_type();
let max_rem =
ctx.builder.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "").unwrap();
ctx.builder
.build_and(
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
ctx.builder.build_not(max_rem, "").unwrap(),
"",
)
.unwrap()
}
let llvm_i8 = ctx.ctx.i8_type();
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let llvm_usize = generator.get_size_type(ctx.ctx);
let align_ty = align_ty.into();
let size = ctx.builder.build_int_truncate_or_bit_cast(size, llvm_usize, "").unwrap();
debug_assert_eq!(
size.get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected size_t ({}) for parameter `size` of `aligned_alloca`, got {}",
llvm_usize,
size.get_type(),
);
let alignment = get_type_alignment(align_ty);
let alignment = ctx.builder.build_int_truncate_or_bit_cast(alignment, llvm_usize, "").unwrap();
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let alignment_bitcount = llvm_intrinsics::call_int_ctpop(ctx, alignment, None);
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::EQ,
alignment_bitcount,
alignment_bitcount.get_type().const_int(1, false),
"",
)
.unwrap(),
"0:AssertionError",
"Expected power-of-two alignment for aligned_alloca, got {0}",
[Some(alignment), None, None],
ctx.current_loc,
);
}
let buffer_size = round_up(ctx, size, alignment);
let aligned_slices = ctx.builder.build_int_unsigned_div(buffer_size, alignment, "").unwrap();
// Just to be absolutely sure, alloca in [i8 x alignment] slices
let buffer = ctx.builder.build_array_alloca(align_ty, aligned_slices, "").unwrap();
ctx.builder
.build_bit_cast(buffer, llvm_pi8, name.unwrap_or_default())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,29 +1,37 @@
use crate::{ use std::{
codegen::{ collections::{HashMap, HashSet},
concrete_type::ConcreteTypeStore, CodeGenContext, CodeGenLLVMOptions, sync::Arc,
CodeGenTargetMachineOptions, CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry,
},
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{
composer::TopLevelComposer, DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
},
typecheck::{
type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
},
}; };
use indexmap::IndexMap;
use indoc::indoc; use indoc::indoc;
use inkwell::{ use inkwell::{
targets::{InitializationConfig, Target}, targets::{InitializationConfig, Target},
OptimizationLevel OptimizationLevel,
}; };
use nac3parser::{ use nac3parser::{
ast::{fold::Fold, StrRef}, ast::{fold::Fold, FileName, StrRef},
parser::parse_program, parser::parse_program,
}; };
use parking_lot::RwLock; use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::Arc; use super::{
concrete_type::ConcreteTypeStore,
types::{ndarray::NDArrayType, ListType, ProxyType, RangeType},
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator,
DefaultCodeGenerator, WithCall, WorkerRegistry,
};
use crate::{
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
},
typecheck::{
type_inferencer::{FunctionData, IdentifierInfo, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
};
struct Resolver { struct Resolver {
id_to_type: HashMap<StrRef, Type>, id_to_type: HashMap<StrRef, Type>,
@ -52,13 +60,14 @@ impl SymbolResolver for Resolver {
_: &PrimitiveStore, _: &PrimitiveStore,
str: StrRef, str: StrRef,
) -> Result<Type, String> { ) -> Result<Type, String> {
self.id_to_type.get(&str).cloned().ok_or_else(|| format!("cannot find symbol `{}`", str)) self.id_to_type.get(&str).copied().ok_or_else(|| format!("cannot find symbol `{str}`"))
} }
fn get_symbol_value<'ctx, 'a>( fn get_symbol_value<'ctx>(
&self, &self,
_: StrRef, _: StrRef,
_: &mut CodeGenContext<'ctx, 'a>, _: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
unimplemented!() unimplemented!()
} }
@ -67,10 +76,8 @@ impl SymbolResolver for Resolver {
self.id_to_def self.id_to_def
.read() .read()
.get(&id) .get(&id)
.cloned() .copied()
.ok_or_else(|| HashSet::from([ .ok_or_else(|| HashSet::from([format!("cannot find symbol `{id}`")]))
format!("cannot find symbol `{}`", id),
]))
} }
fn get_string_id(&self, _: &str) -> i32 { fn get_string_id(&self, _: &str) -> i32 {
@ -89,9 +96,9 @@ fn test_primitives() {
d = a if c == 1 else 0 d = a if c == 1 else 0
return d return d
"}; "};
let statements = parse_program(source, Default::default()).unwrap(); let statements = parse_program(source, FileName::default()).unwrap();
let composer: TopLevelComposer = Default::default(); let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
@ -100,17 +107,27 @@ fn test_primitives() {
let resolver = Arc::new(Resolver { let resolver = Arc::new(Resolver {
id_to_type: HashMap::new(), id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()), id_to_def: RwLock::new(HashMap::new()),
class_names: Default::default(), class_names: HashMap::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>; }) as Arc<dyn SymbolResolver + Send + Sync>;
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let signature = FunSignature { let signature = FunSignature {
args: vec![ args: vec![
FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }, FuncArg {
FuncArg { name: "b".into(), ty: primitives.int32, default_value: None }, name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "b".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
], ],
ret: primitives.int32, ret: primitives.int32,
vars: HashMap::new(), vars: VarMap::new(),
}; };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
@ -125,12 +142,13 @@ fn test_primitives() {
}; };
let mut virtual_checks = Vec::new(); let mut virtual_checks = Vec::new();
let mut calls = HashMap::new(); let mut calls = HashMap::new();
let mut identifiers: HashSet<_> = ["a".into(), "b".into()].iter().cloned().collect(); let mut identifiers: HashMap<_, _> =
["a".into(), "b".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut inferencer = Inferencer { let mut inferencer = Inferencer {
top_level: &top_level, top_level: &top_level,
function_data: &mut function_data, function_data: &mut function_data,
unifier: &mut unifier, unifier: &mut unifier,
variable_mapping: Default::default(), variable_mapping: HashMap::default(),
primitives: &primitives, primitives: &primitives,
virtual_checks: &mut virtual_checks, virtual_checks: &mut virtual_checks,
calls: &mut calls, calls: &mut calls,
@ -154,7 +172,7 @@ fn test_primitives() {
}); });
let task = CodeGenTask { let task = CodeGenTask {
subst: Default::default(), subst: Vec::default(),
symbol_name: "testing".into(), symbol_name: "testing".into(),
body: Arc::new(statements), body: Arc::new(statements),
unifier_index: 0, unifier_index: 0,
@ -186,6 +204,8 @@ fn test_primitives() {
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 { define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 {
@ -225,12 +245,7 @@ fn test_primitives() {
opt_level: OptimizationLevel::Default, opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(), target: CodeGenTargetMachineOptions::from_host_triple(),
}; };
let (registry, handles) = WorkerRegistry::create_workers( let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
threads,
top_level,
&llvm_options,
&f
);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
} }
@ -241,23 +256,28 @@ fn test_simple_call() {
a = foo(a) a = foo(a)
return a * 2 return a * 2
"}; "};
let statements_1 = parse_program(source_1, Default::default()).unwrap(); let statements_1 = parse_program(source_1, FileName::default()).unwrap();
let source_2 = indoc! { " let source_2 = indoc! { "
return a + 1 return a + 1
"}; "};
let statements_2 = parse_program(source_2, Default::default()).unwrap(); let statements_2 = parse_program(source_2, FileName::default()).unwrap();
let composer: TopLevelComposer = Default::default(); let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone()); unifier.top_level = Some(top_level.clone());
let signature = FunSignature { let signature = FunSignature {
args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }], args: vec![FuncArg {
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
}],
ret: primitives.int32, ret: primitives.int32,
vars: HashMap::new(), vars: VarMap::new(),
}; };
let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone())); let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone()));
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
@ -281,7 +301,7 @@ fn test_simple_call() {
let resolver = Resolver { let resolver = Resolver {
id_to_type: HashMap::new(), id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()), id_to_def: RwLock::new(HashMap::new()),
class_names: Default::default(), class_names: HashMap::default(),
}; };
resolver.add_id_def("foo".into(), DefinitionId(foo_id)); resolver.add_id_def("foo".into(), DefinitionId(foo_id));
let resolver = Arc::new(resolver) as Arc<dyn SymbolResolver + Send + Sync>; let resolver = Arc::new(resolver) as Arc<dyn SymbolResolver + Send + Sync>;
@ -302,12 +322,13 @@ fn test_simple_call() {
}; };
let mut virtual_checks = Vec::new(); let mut virtual_checks = Vec::new();
let mut calls = HashMap::new(); let mut calls = HashMap::new();
let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].iter().cloned().collect(); let mut identifiers: HashMap<_, _> =
["a".into(), "foo".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut inferencer = Inferencer { let mut inferencer = Inferencer {
top_level: &top_level, top_level: &top_level,
function_data: &mut function_data, function_data: &mut function_data,
unifier: &mut unifier, unifier: &mut unifier,
variable_mapping: Default::default(), variable_mapping: HashMap::default(),
primitives: &primitives, primitives: &primitives,
virtual_checks: &mut virtual_checks, virtual_checks: &mut virtual_checks,
calls: &mut calls, calls: &mut calls,
@ -336,11 +357,11 @@ fn test_simple_call() {
&mut *top_level.definitions.read()[foo_id].write() &mut *top_level.definitions.read()[foo_id].write()
{ {
instance_to_stmt.insert( instance_to_stmt.insert(
"".to_string(), String::new(),
FunInstance { FunInstance {
body: Arc::new(statements_2), body: Arc::new(statements_2),
calls: Arc::new(inferencer.calls.clone()), calls: Arc::new(inferencer.calls.clone()),
subst: Default::default(), subst: IndexMap::default(),
unifier_id: 0, unifier_id: 0,
}, },
); );
@ -356,7 +377,7 @@ fn test_simple_call() {
}); });
let task = CodeGenTask { let task = CodeGenTask {
subst: Default::default(), subst: Vec::default(),
symbol_name: "testing".to_string(), symbol_name: "testing".to_string(),
body: Arc::new(statements_1), body: Arc::new(statements_1),
calls: Arc::new(calls1), calls: Arc::new(calls1),
@ -370,6 +391,8 @@ fn test_simple_call() {
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 { define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 {
@ -415,12 +438,39 @@ fn test_simple_call() {
opt_level: OptimizationLevel::Default, opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(), target: CodeGenTargetMachineOptions::from_host_triple(),
}; };
let (registry, handles) = WorkerRegistry::create_workers( let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
threads,
top_level,
&llvm_options,
&f
);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
} }
#[test]
fn test_classes_list_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);
let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into());
assert!(ListType::is_representable(llvm_list.as_base_type(), llvm_usize).is_ok());
}
#[test]
fn test_classes_range_type_new() {
let ctx = inkwell::context::Context::create();
let llvm_range = RangeType::new(&ctx);
assert!(RangeType::is_representable(llvm_range.as_base_type()).is_ok());
}
#[test]
fn test_classes_ndarray_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into(), None);
assert!(NDArrayType::is_representable(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
}

View File

@ -0,0 +1,206 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::IntValue,
AddressSpace,
};
use super::ProxyType;
use crate::codegen::{
values::{ArraySliceValue, ListValue, ProxyValue},
CodeGenContext, CodeGenerator,
};
/// Proxy type for a `list` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct ListType<'ctx> {
ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
}
impl<'ctx> ListType<'ctx> {
/// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not.
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let llvm_list_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else {
return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}"));
};
if llvm_list_ty.count_fields() != 2 {
return Err(format!(
"Expected 2 fields in `list`, got {}",
llvm_list_ty.count_fields()
));
}
let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap();
let Ok(_) = PointerType::try_from(list_size_ty) else {
return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}"));
};
let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap();
let Ok(list_data_ty) = IntType::try_from(list_data_ty) else {
return Err(format!("Expected int type for `list.1`, got {list_data_ty}"));
};
if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!(
"Expected {}-bit int type for `list.1`, got {}-bit int",
llvm_usize.get_bit_width(),
list_data_ty.get_bit_width()
));
}
Ok(())
}
/// Creates an LLVM type corresponding to the expected structure of a `List`.
#[must_use]
fn llvm_type(
ctx: &'ctx Context,
element_type: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
) -> PointerType<'ctx> {
// struct List { data: T*, size: size_t }
let field_tys = [element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()];
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`ListType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
element_type: BasicTypeEnum<'ctx>,
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_list = Self::llvm_type(ctx, element_type, llvm_usize);
ListType::from_type(llvm_list, llvm_usize)
}
/// Creates an [`ListType`] from a [`PointerType`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
ListType { ty: ptr_ty, llvm_usize }
}
/// Returns the type of the `size` field of this `list` type.
#[must_use]
pub fn size_type(&self) -> IntType<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(1)
.map(BasicTypeEnum::into_int_type)
.unwrap()
}
/// Returns the element type of this `list` type.
#[must_use]
pub fn element_type(&self) -> AnyTypeEnum<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(0)
.map(BasicTypeEnum::into_pointer_type)
.map(PointerType::get_element_type)
.unwrap()
}
/// Allocates an instance of [`ListValue`] as if by calling `alloca` on the base type.
#[must_use]
pub fn alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
self.raw_alloca(generator, ctx, name),
self.llvm_usize,
name,
)
}
/// Converts an existing value into a [`ListValue`].
#[must_use]
pub fn map_value(
&self,
value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(value, self.llvm_usize, name)
}
}
impl<'ctx> ProxyType<'ctx> for ListType<'ctx> {
type Base = PointerType<'ctx>;
type Value = ListValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn raw_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self::Value as ProxyValue<'ctx>>::Base {
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap()
}
fn array_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<ListType<'ctx>> for PointerType<'ctx> {
fn from(value: ListType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -0,0 +1,76 @@
//! This module contains abstraction over all intrinsic composite types of NAC3.
//!
//! # `raw_alloca` vs `alloca` vs `construct`
//!
//! There are three ways of creating a new object instance using the abstractions provided by this
//! module.
//!
//! - `raw_alloca`: Allocates the object on the stack, returning an instance of
//! [`impl BasicValue`][inkwell::values::BasicValue]. This is similar to a `malloc` expression in
//! C++ but the object is allocated on the stack.
//! - `alloca`: Similar to `raw_alloca`, but also wraps the allocated object with
//! [`<Self as ProxyType<'ctx>>::Value`][ProxyValue], and returns the wrapped object. The returned
//! object will not initialize any value or fields. This is similar to a type-safe `malloc`
//! expression in C++ but the object is allocated on the stack.
//! - `construct`: Similar to `alloca`, but performs some initialization on the value or fields of
//! the returned object. This is similar to a `new` expression in C++ but the object is allocated
//! on the stack.
use inkwell::{context::Context, types::BasicType, values::IntValue};
use super::{
values::{ArraySliceValue, ProxyValue},
{CodeGenContext, CodeGenerator},
};
pub use list::*;
pub use range::*;
mod list;
pub mod ndarray;
mod range;
pub mod structure;
pub mod utils;
/// A LLVM type that is used to represent a corresponding type in NAC3.
pub trait ProxyType<'ctx>: Into<Self::Base> {
/// The LLVM type of which values of this type possess. This is usually a
/// [LLVM pointer type][PointerType] for any non-primitive types.
type Base: BasicType<'ctx>;
/// The type of values represented by this type.
type Value: ProxyValue<'ctx, Type = Self>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String>;
/// Checks whether `llvm_ty` can be represented by this [`ProxyType`].
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String>;
/// Creates a new value of this type, returning the LLVM instance of this value.
fn raw_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self::Value as ProxyValue<'ctx>>::Base;
/// Creates a new array value of this type, returning an [`ArraySliceValue`] encapsulating the
/// resulting array.
fn array_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx>;
/// Returns the [base type][Self::Base] of this proxy.
fn as_base_type(&self) -> Self::Base;
}

View File

@ -0,0 +1,257 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use nac3core_derive::StructFields;
use crate::{
codegen::{
types::{
structure::{
check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields,
},
ProxyType,
},
values::{ndarray::ContiguousNDArrayValue, ArraySliceValue, ProxyValue},
CodeGenContext, CodeGenerator,
},
toplevel::numpy::unpack_ndarray_var_tys,
typecheck::typedef::Type,
};
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct ContiguousNDArrayType<'ctx> {
ty: PointerType<'ctx>,
item: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
}
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct ContiguousNDArrayFields<'ctx> {
#[value_type(usize)]
pub ndims: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub shape: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
pub data: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> ContiguousNDArrayFields<'ctx> {
#[must_use]
pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
let mut counter = FieldIndexCounter::default();
ContiguousNDArrayFields {
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
shape: StructField::create(
&mut counter,
"shape",
llvm_usize.ptr_type(AddressSpace::default()),
),
data: StructField::create(&mut counter, "data", item.ptr_type(AddressSpace::default())),
}
}
}
impl<'ctx> ContiguousNDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let ctx = llvm_ty.get_context();
let llvm_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else {
return Err(format!(
"Expected struct type for `ContiguousNDArray` type, got {llvm_ty}"
));
};
let fields = ContiguousNDArrayFields::new(ctx, llvm_usize);
check_struct_type_matches_fields(
fields,
llvm_ty,
"ContiguousNDArray",
&[(fields.data.name(), &|ty| {
if ty.is_pointer_type() {
Ok(())
} else {
Err(format!("Expected T* for `ContiguousNDArray.data`, got {ty}"))
}
})],
)
}
/// Returns an instance of [`StructFields`] containing all field accessors for this type.
#[must_use]
fn fields(
item: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
) -> ContiguousNDArrayFields<'ctx> {
ContiguousNDArrayFields::new_typed(item, llvm_usize)
}
/// See [`NDArrayType::fields`].
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(&self) -> ContiguousNDArrayFields<'ctx> {
Self::fields(self.item, self.llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
#[must_use]
fn llvm_type(
ctx: &'ctx Context,
item: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
) -> PointerType<'ctx> {
let field_tys =
Self::fields(item, llvm_usize).into_iter().map(|field| field.1).collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`ContiguousNDArrayType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
item: BasicTypeEnum<'ctx>,
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize);
Self { ty: llvm_cndarray, item, llvm_usize }
}
/// Creates an [`ContiguousNDArrayType`] from a [unifier type][Type].
#[must_use]
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type,
) -> Self {
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
let llvm_usize = generator.get_size_type(ctx.ctx);
Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize }
}
/// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`.
#[must_use]
pub fn from_type(
ptr_ty: PointerType<'ctx>,
item: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
Self { ty: ptr_ty, item, llvm_usize }
}
/// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base type.
#[must_use]
pub fn alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
self.raw_alloca(generator, ctx, name),
self.item,
self.llvm_usize,
name,
)
}
/// Converts an existing value into a [`ContiguousNDArrayValue`].
#[must_use]
pub fn map_value(
&self,
value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
value,
self.item,
self.llvm_usize,
name,
)
}
}
impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> {
type Base = PointerType<'ctx>;
type Value = ContiguousNDArrayValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn raw_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self::Value as ProxyValue<'ctx>>::Base {
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap()
}
fn array_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<ContiguousNDArrayType<'ctx>> for PointerType<'ctx> {
fn from(value: ContiguousNDArrayType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -0,0 +1,215 @@
use inkwell::{
context::{AsContextRef, Context},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use nac3core_derive::StructFields;
use crate::codegen::{
types::{
structure::{check_struct_type_matches_fields, StructField, StructFields},
ProxyType,
},
values::{
ndarray::{NDIndexValue, RustNDIndex},
ArrayLikeIndexer, ArraySliceValue, ProxyValue,
},
CodeGenContext, CodeGenerator,
};
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NDIndexType<'ctx> {
ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
}
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDIndexStructFields<'ctx> {
#[value_type(i8_type())]
pub type_: StructField<'ctx, IntValue<'ctx>>,
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
pub data: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> NDIndexType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndindex` type, returning [Err] if it does not.
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let ctx = llvm_ty.get_context();
let llvm_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else {
return Err(format!(
"Expected struct type for `ContiguousNDArray` type, got {llvm_ty}"
));
};
let fields = NDIndexStructFields::new(ctx, llvm_usize);
check_struct_type_matches_fields(fields, llvm_ty, "NDIndex", &[])
}
#[must_use]
fn fields(
ctx: impl AsContextRef<'ctx>,
llvm_usize: IntType<'ctx>,
) -> NDIndexStructFields<'ctx> {
NDIndexStructFields::new(ctx, llvm_usize)
}
#[must_use]
pub fn get_fields(&self) -> NDIndexStructFields<'ctx> {
Self::fields(self.ty.get_context(), self.llvm_usize)
}
#[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
let field_tys =
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(generator: &G, ctx: &'ctx Context) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_ndindex = Self::llvm_type(ctx, llvm_usize);
Self { ty: llvm_ndindex, llvm_usize }
}
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
Self { ty: ptr_ty, llvm_usize }
}
#[must_use]
pub fn alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
self.raw_alloca(generator, ctx, name),
self.llvm_usize,
name,
)
}
/// Serialize a list of [`RustNDIndex`] as a newly allocated LLVM array of [`NDIndexValue`].
#[must_use]
pub fn construct_ndindices<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
in_ndindices: &[RustNDIndex<'ctx>],
) -> ArraySliceValue<'ctx> {
// Allocate the LLVM ndindices.
let num_ndindices = self.llvm_usize.const_int(in_ndindices.len() as u64, false);
let ndindices = self.array_alloca(generator, ctx, num_ndindices, None);
// Initialize all of them.
for (i, in_ndindex) in in_ndindices.iter().enumerate() {
let pndindex = unsafe {
ndindices.ptr_offset_unchecked(
ctx,
generator,
&ctx.ctx.i64_type().const_int(u64::try_from(i).unwrap(), false),
None,
)
};
in_ndindex.write_to_ndindex(
generator,
ctx,
NDIndexValue::from_pointer_value(pndindex, self.llvm_usize, None),
);
}
ndindices
}
#[must_use]
pub fn map_value(
&self,
value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(value, self.llvm_usize, name)
}
}
impl<'ctx> ProxyType<'ctx> for NDIndexType<'ctx> {
type Base = PointerType<'ctx>;
type Value = NDIndexValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn raw_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self::Value as ProxyValue<'ctx>>::Base {
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap()
}
fn array_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<NDIndexType<'ctx>> for PointerType<'ctx> {
fn from(value: NDIndexType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -0,0 +1,469 @@
use inkwell::{
context::{AsContextRef, Context},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{BasicValue, IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use nac3core_derive::StructFields;
use super::{
structure::{check_struct_type_matches_fields, StructField, StructFields},
ProxyType,
};
use crate::{
codegen::{
values::{ndarray::NDArrayValue, ArraySliceValue, ProxyValue, TypedArrayLikeMutator},
{CodeGenContext, CodeGenerator},
},
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
typecheck::typedef::Type,
};
pub use contiguous::*;
pub use indexing::*;
pub use nditer::*;
mod contiguous;
mod indexing;
mod nditer;
/// Proxy type for a `ndarray` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NDArrayType<'ctx> {
ty: PointerType<'ctx>,
dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
llvm_usize: IntType<'ctx>,
}
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDArrayStructFields<'ctx> {
/// The size of each `NDArray` element in bytes.
#[value_type(usize)]
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
/// Number of dimensions in the array.
#[value_type(usize)]
pub ndims: StructField<'ctx, IntValue<'ctx>>,
/// Pointer to an array containing the shape of the `NDArray`.
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub shape: StructField<'ctx, PointerValue<'ctx>>,
/// Pointer to an array indicating the number of bytes between each element at a dimension
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub strides: StructField<'ctx, PointerValue<'ctx>>,
/// Pointer to an array containing the array data
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
pub data: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let ctx = llvm_ty.get_context();
let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
};
check_struct_type_matches_fields(
Self::fields(ctx, llvm_usize),
llvm_ndarray_ty,
"NDArray",
&[],
)
}
/// Returns an instance of [`StructFields`] containing all field accessors for this type.
#[must_use]
fn fields(
ctx: impl AsContextRef<'ctx>,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
NDArrayStructFields::new(ctx, llvm_usize)
}
/// See [`NDArrayType::fields`].
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, self.llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
#[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
let field_tys =
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`NDArrayType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
NDArrayType { ty: llvm_ndarray, dtype, ndims, llvm_usize }
}
/// Creates an instance of [`NDArrayType`] with `ndims` of 0.
#[must_use]
pub fn new_unsized<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>,
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_ndarray = Self::llvm_type(ctx, llvm_usize);
NDArrayType { ty: llvm_ndarray, dtype, ndims: Some(0), llvm_usize }
}
/// Creates an [`NDArrayType`] from a [unifier type][Type].
#[must_use]
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type,
) -> Self {
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndims = extract_ndims(&ctx.unifier, ndims);
NDArrayType {
ty: Self::llvm_type(ctx.ctx, llvm_usize),
dtype: llvm_dtype,
ndims: Some(ndims),
llvm_usize,
}
}
/// Creates an [`NDArrayType`] from a [`PointerType`] representing an `NDArray`.
#[must_use]
pub fn from_type(
ptr_ty: PointerType<'ctx>,
dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
llvm_usize: IntType<'ctx>,
) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
NDArrayType { ty: ptr_ty, dtype, ndims, llvm_usize }
}
/// Returns the type of the `size` field of this `ndarray` type.
#[must_use]
pub fn size_type(&self) -> IntType<'ctx> {
self.llvm_usize
}
/// Returns the element type of this `ndarray` type.
#[must_use]
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.dtype
}
/// Returns the number of dimensions of this `ndarray` type.
#[must_use]
pub fn ndims(&self) -> Option<u64> {
self.ndims
}
/// Allocates an instance of [`NDArrayValue`] as if by calling `alloca` on the base type.
#[must_use]
pub fn alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
self.raw_alloca(generator, ctx, name),
self.dtype,
self.ndims,
self.llvm_usize,
name,
)
}
/// Allocates an [`NDArrayValue`] on the stack and initializes all fields as follows:
///
/// - `data`: uninitialized.
/// - `itemsize`: set to the size of `self.dtype`.
/// - `ndims`: set to the value of `ndims`.
/// - `shape`: allocated on the stack with an array of length `ndims` with uninitialized values.
/// - `strides`: allocated on the stack with an array of length `ndims` with uninitialized
/// values.
#[must_use]
fn construct_impl<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndims: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
let ndarray = self.alloca(generator, ctx, name);
let itemsize = ctx
.builder
.build_int_truncate_or_bit_cast(self.dtype.size_of().unwrap(), self.llvm_usize, "")
.unwrap();
ndarray.store_itemsize(ctx, generator, itemsize);
ndarray.store_ndims(ctx, generator, ndims);
ndarray.create_shape(ctx, self.llvm_usize, ndims);
ndarray.create_strides(ctx, self.llvm_usize, ndims);
ndarray
}
/// Allocate an [`NDArrayValue`] on the stack using `dtype` and `ndims` of this [`NDArrayType`]
/// instance.
///
/// The returned ndarray's content will be:
/// - `data`: uninitialized.
/// - `itemsize`: set to the size of `dtype`.
/// - `ndims`: set to the value of `self.ndims`.
/// - `shape`: allocated on the stack with an array of length `ndims` with uninitialized values.
/// - `strides`: allocated on the stack with an array of length `ndims` with uninitialized
/// values.
#[must_use]
pub fn construct_uninitialized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
assert!(self.ndims.is_some(), "NDArrayType::construct can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
let Some(ndims) = self.ndims.map(|ndims| self.llvm_usize.const_int(ndims, false)) else {
unreachable!()
};
self.construct_impl(generator, ctx, ndims, name)
}
/// Allocate an [`NDArrayValue`] on the stack given its `ndims` and `dtype`.
///
/// `shape` and `strides` will be automatically allocated onto the stack.
///
/// The returned ndarray's content will be:
/// - `data`: uninitialized.
/// - `itemsize`: set to the size of `dtype`.
/// - `ndims`: set to the value of `ndims`.
/// - `shape`: allocated with an array of length `ndims` with uninitialized values.
/// - `strides`: allocated with an array of length `ndims` with uninitialized values.
#[deprecated = "Prefer construct_uninitialized or construct_*_shape."]
#[must_use]
pub fn construct_dyn_ndims<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndims: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
assert!(self.ndims.is_none(), "NDArrayType::construct_dyn_ndims can only be called on an instance with compile-time unknown ndims (self.ndims = None)");
self.construct_impl(generator, ctx, ndims, name)
}
/// Convenience function. Allocate an [`NDArrayValue`] with a statically known shape.
///
/// The returned [`NDArrayValue`]'s `data` and `strides` are uninitialized.
#[must_use]
pub fn construct_const_shape<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: &[u64],
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
assert!(self.ndims.is_none_or(|ndims| shape.len() as u64 == ndims));
let ndarray = Self::new(generator, ctx.ctx, self.dtype, Some(shape.len() as u64))
.construct_uninitialized(generator, ctx, name);
let llvm_usize = generator.get_size_type(ctx.ctx);
// Write shape
let ndarray_shape = ndarray.shape();
for (i, dim) in shape.iter().enumerate() {
let dim = llvm_usize.const_int(*dim, false);
unsafe {
ndarray_shape.set_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
dim,
);
}
}
ndarray
}
/// Convenience function. Allocate an [`NDArrayValue`] with a dynamically known shape.
///
/// The returned [`NDArrayValue`]'s `data` and `strides` are uninitialized.
#[must_use]
pub fn construct_dyn_shape<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: &[IntValue<'ctx>],
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
assert!(self.ndims.is_none_or(|ndims| shape.len() as u64 == ndims));
let ndarray = Self::new(generator, ctx.ctx, self.dtype, Some(shape.len() as u64))
.construct_uninitialized(generator, ctx, name);
let llvm_usize = generator.get_size_type(ctx.ctx);
// Write shape
let ndarray_shape = ndarray.shape();
for (i, dim) in shape.iter().enumerate() {
assert_eq!(
dim.get_type(),
llvm_usize,
"Expected {} but got {}",
llvm_usize.print_to_string(),
dim.get_type().print_to_string()
);
unsafe {
ndarray_shape.set_typed_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
*dim,
);
}
}
ndarray
}
/// Create an unsized ndarray to contain `value`.
#[must_use]
pub fn construct_unsized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
value: &impl BasicValue<'ctx>,
name: Option<&'ctx str>,
) -> NDArrayValue<'ctx> {
let value = value.as_basic_value_enum();
assert_eq!(value.get_type(), self.dtype);
assert!(self.ndims.is_none_or(|ndims| ndims == 0));
// We have to put the value on the stack to get a data pointer.
let data = ctx.builder.build_alloca(value.get_type(), "construct_unsized").unwrap();
ctx.builder.build_store(data, value).unwrap();
let data = ctx
.builder
.build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
.unwrap();
let ndarray = Self::new_unsized(generator, ctx.ctx, value.get_type())
.construct_uninitialized(generator, ctx, name);
ctx.builder.build_store(ndarray.ptr_to_data(ctx), data).unwrap();
ndarray
}
/// Converts an existing value into a [`NDArrayValue`].
#[must_use]
pub fn map_value(
&self,
value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
value,
self.dtype,
self.ndims,
self.llvm_usize,
name,
)
}
}
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
type Base = PointerType<'ctx>;
type Value = NDArrayValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn raw_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self::Value as ProxyValue<'ctx>>::Base {
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap()
}
fn array_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<NDArrayType<'ctx>> for PointerType<'ctx> {
fn from(value: NDArrayType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -0,0 +1,241 @@
use inkwell::{
context::{AsContextRef, Context},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use nac3core_derive::StructFields;
use super::ProxyType;
use crate::codegen::{
irrt,
types::structure::{check_struct_type_matches_fields, StructField, StructFields},
values::{
ndarray::{NDArrayValue, NDIterValue},
ArraySliceValue, ProxyValue,
},
CodeGenContext, CodeGenerator,
};
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NDIterType<'ctx> {
ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
}
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct NDIterStructFields<'ctx> {
#[value_type(usize)]
pub ndims: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub shape: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub strides: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub indices: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
pub nth: StructField<'ctx, IntValue<'ctx>>,
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
pub element: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(usize)]
pub size: StructField<'ctx, IntValue<'ctx>>,
}
impl<'ctx> NDIterType<'ctx> {
/// Checks whether `llvm_ty` represents a `nditer` type, returning [Err] if it does not.
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let ctx = llvm_ty.get_context();
let llvm_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ty else {
return Err(format!("Expected struct type for `NDIter` type, got {llvm_ty}"));
};
check_struct_type_matches_fields(
Self::fields(ctx, llvm_usize),
llvm_ndarray_ty,
"NDIter",
&[],
)
}
/// Returns an instance of [`StructFields`] containing all field accessors for this type.
#[must_use]
fn fields(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> NDIterStructFields<'ctx> {
NDIterStructFields::new(ctx, llvm_usize)
}
/// See [`NDIterType::fields`].
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(&self, ctx: impl AsContextRef<'ctx>) -> NDIterStructFields<'ctx> {
Self::fields(ctx, self.llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDIter`.
#[must_use]
fn llvm_type(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> PointerType<'ctx> {
let field_tys =
Self::fields(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`NDIter`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(generator: &G, ctx: &'ctx Context) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_nditer = Self::llvm_type(ctx, llvm_usize);
Self { ty: llvm_nditer, llvm_usize }
}
/// Creates an [`NDIterType`] from a [`PointerType`] representing an `NDIter`.
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
Self { ty: ptr_ty, llvm_usize }
}
/// Returns the type of the `size` field of this `nditer` type.
#[must_use]
pub fn size_type(&self) -> IntType<'ctx> {
self.llvm_usize
}
#[must_use]
pub fn alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
parent: NDArrayValue<'ctx>,
indices: ArraySliceValue<'ctx>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
self.raw_alloca(generator, ctx, name),
parent,
indices,
self.llvm_usize,
name,
)
}
/// Allocate an [`NDIter`] that iterates through the given `ndarray`.
#[must_use]
pub fn construct<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
) -> <Self as ProxyType<'ctx>>::Value {
let nditer = self.raw_alloca(generator, ctx, None);
let ndims = ndarray.load_ndims(ctx);
// The caller has the responsibility to allocate 'indices' for `NDIter`.
let indices =
generator.gen_array_var_alloc(ctx, self.llvm_usize.into(), ndims, None).unwrap();
let nditer = <Self as ProxyType<'ctx>>::Value::from_pointer_value(
nditer,
ndarray,
indices,
self.llvm_usize,
None,
);
irrt::ndarray::call_nac3_nditer_initialize(generator, ctx, nditer, ndarray, indices);
nditer
}
#[must_use]
pub fn map_value(
&self,
value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
parent: NDArrayValue<'ctx>,
indices: ArraySliceValue<'ctx>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
value,
parent,
indices,
self.llvm_usize,
name,
)
}
}
impl<'ctx> ProxyType<'ctx> for NDIterType<'ctx> {
type Base = PointerType<'ctx>;
type Value = NDIterValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn raw_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self::Value as ProxyValue<'ctx>>::Base {
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap()
}
fn array_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<NDIterType<'ctx>> for PointerType<'ctx> {
fn from(value: NDIterType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -0,0 +1,170 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::IntValue,
AddressSpace,
};
use super::ProxyType;
use crate::codegen::{
values::{ArraySliceValue, ProxyValue, RangeValue},
{CodeGenContext, CodeGenerator},
};
/// Proxy type for a `range` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct RangeType<'ctx> {
ty: PointerType<'ctx>,
}
impl<'ctx> RangeType<'ctx> {
/// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not.
pub fn is_representable(llvm_ty: PointerType<'ctx>) -> Result<(), String> {
let llvm_range_ty = llvm_ty.get_element_type();
let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else {
return Err(format!("Expected array type for `range` type, got {llvm_range_ty}"));
};
if llvm_range_ty.len() != 3 {
return Err(format!(
"Expected 3 elements for `range` type, got {}",
llvm_range_ty.len()
));
}
let llvm_range_elem_ty = llvm_range_ty.get_element_type();
let Ok(llvm_range_elem_ty) = IntType::try_from(llvm_range_elem_ty) else {
return Err(format!(
"Expected int type for `range` element type, got {llvm_range_elem_ty}"
));
};
if llvm_range_elem_ty.get_bit_width() != 32 {
return Err(format!(
"Expected 32-bit int type for `range` element type, got {}",
llvm_range_elem_ty.get_bit_width()
));
}
Ok(())
}
/// Creates an LLVM type corresponding to the expected structure of a `Range`.
#[must_use]
fn llvm_type(ctx: &'ctx Context) -> PointerType<'ctx> {
// typedef int32_t Range[3];
let llvm_i32 = ctx.i32_type();
llvm_i32.array_type(3).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`RangeType`].
#[must_use]
pub fn new(ctx: &'ctx Context) -> Self {
let llvm_range = Self::llvm_type(ctx);
RangeType::from_type(llvm_range)
}
/// Creates an [`RangeType`] from a [`PointerType`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self {
debug_assert!(Self::is_representable(ptr_ty).is_ok());
RangeType { ty: ptr_ty }
}
/// Returns the type of all fields of this `range` type.
#[must_use]
pub fn value_type(&self) -> IntType<'ctx> {
self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type()
}
/// Allocates an instance of [`RangeValue`] as if by calling `alloca` on the base type.
#[must_use]
pub fn alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
self.raw_alloca(generator, ctx, name),
name,
)
}
/// Converts an existing value into a [`RangeValue`].
#[must_use]
pub fn map_value(
&self,
value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(value, name)
}
}
impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
type Base = PointerType<'ctx>;
type Value = RangeValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
_: &G,
_: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty)
}
fn raw_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self::Value as ProxyValue<'ctx>>::Base {
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap()
}
fn array_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<RangeType<'ctx>> for PointerType<'ctx> {
fn from(value: RangeType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -0,0 +1,255 @@
use std::marker::PhantomData;
use inkwell::{
context::AsContextRef,
types::{BasicTypeEnum, IntType, StructType},
values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
};
use crate::codegen::CodeGenContext;
/// Trait indicating that the structure is a field-wise representation of an LLVM structure.
///
/// # Usage
///
/// For example, for a simple C-slice LLVM structure:
///
/// ```ignore
/// struct CSliceFields<'ctx> {
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
/// len: StructField<'ctx, IntValue<'ctx>>
/// }
/// ```
pub trait StructFields<'ctx>: Eq + Copy {
/// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types.
fn new(ctx: impl AsContextRef<'ctx>, llvm_usize: IntType<'ctx>) -> Self;
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
/// the type definition.
#[must_use]
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>;
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
/// in the type definition.
#[must_use]
fn iter(&self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)> {
self.to_vec().into_iter()
}
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
/// the type definition.
#[must_use]
fn into_vec(self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>
where
Self: Sized,
{
self.to_vec()
}
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
/// in the type definition.
#[must_use]
fn into_iter(self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)>
where
Self: Sized,
{
self.into_vec().into_iter()
}
}
/// A single field of an LLVM structure.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct StructField<'ctx, Value>
where
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
{
/// The index of this field within the structure.
index: u32,
/// The name of this field.
name: &'static str,
/// The type of this field.
ty: BasicTypeEnum<'ctx>,
/// Instance of [`PhantomData`] containing [`Value`], used to implement automatic downcasts.
_value_ty: PhantomData<Value>,
}
impl<'ctx, Value> StructField<'ctx, Value>
where
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
{
/// Creates an instance of [`StructField`].
///
/// * `idx_counter` - The instance of [`FieldIndexCounter`] used to track the current field
/// index.
/// * `name` - Name of the field.
/// * `ty` - The type of this field.
pub fn create(
idx_counter: &mut FieldIndexCounter,
name: &'static str,
ty: impl Into<BasicTypeEnum<'ctx>>,
) -> Self {
StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData }
}
/// Creates an instance of [`StructField`] with a given index.
///
/// * `index` - The index of this field within its enclosing structure.
/// * `name` - Name of the field.
/// * `ty` - The type of this field.
pub fn create_at(index: u32, name: &'static str, ty: impl Into<BasicTypeEnum<'ctx>>) -> Self {
StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
}
/// Returns the name of this field.
#[must_use]
pub fn name(&self) -> &'static str {
self.name
}
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
/// {idx...}, i32 {self.index}`.
pub fn ptr_by_array_gep(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
idx: &[IntValue<'ctx>],
) -> PointerValue<'ctx> {
unsafe {
ctx.builder.build_in_bounds_gep(
pobj,
&[idx, &[ctx.ctx.i32_type().const_int(u64::from(self.index), false)]].concat(),
"",
)
}
.unwrap()
}
/// Creates a pointer to this field in an arbitrary structure by performing the equivalent of
/// `getelementptr i32 0, i32 {self.index}`.
pub fn ptr_by_gep(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
obj_name: Option<&'ctx str>,
) -> PointerValue<'ctx> {
ctx.builder
.build_struct_gep(
pobj,
self.index,
&obj_name.map(|name| format!("{name}.{}.addr", self.name)).unwrap_or_default(),
)
.unwrap()
}
/// Gets the value of this field for a given `obj`.
#[must_use]
pub fn get_from_value(&self, obj: StructValue<'ctx>) -> Value {
obj.get_field_at_index(self.index).and_then(|value| Value::try_from(value).ok()).unwrap()
}
/// Sets the value of this field for a given `obj`.
pub fn set_for_value(&self, obj: StructValue<'ctx>, value: Value) {
obj.set_field_at_index(self.index, value);
}
/// Gets the value of this field for a pointer-to-structure.
pub fn get(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
obj_name: Option<&'ctx str>,
) -> Value {
ctx.builder
.build_load(
self.ptr_by_gep(ctx, pobj, obj_name),
&obj_name.map(|name| format!("{name}.{}", self.name)).unwrap_or_default(),
)
.map_err(|_| ())
.and_then(|value| Value::try_from(value))
.unwrap()
}
/// Sets the value of this field for a pointer-to-structure.
pub fn set(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
value: Value,
obj_name: Option<&'ctx str>,
) {
ctx.builder.build_store(self.ptr_by_gep(ctx, pobj, obj_name), value).unwrap();
}
}
impl<'ctx, Value> From<StructField<'ctx, Value>> for (&'static str, BasicTypeEnum<'ctx>)
where
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
{
fn from(value: StructField<'ctx, Value>) -> Self {
(value.name, value.ty)
}
}
/// A counter that tracks the next index of a field using a monotonically increasing counter.
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
pub struct FieldIndexCounter(u32);
impl FieldIndexCounter {
/// Increments the number stored by this counter, returning the previous value.
///
/// Functionally equivalent to `i++` in C-based languages.
pub fn increment(&mut self) -> u32 {
let v = self.0;
self.0 += 1;
v
}
}
type FieldTypeVerifier<'ctx> = dyn Fn(BasicTypeEnum<'ctx>) -> Result<(), String>;
/// Checks whether [`llvm_ty`][StructType] contains the fields described by the given
/// [`StructFields`] instance.
///
/// By default, this function will compare the type of each field in `expected_fields` against
/// `llvm_ty`. To override this behavior for individual fields, pass in overrides to
/// `custom_verifiers`, which will use the specified verifier when a field with the matching field
/// name is being checked.
pub(super) fn check_struct_type_matches_fields<'ctx>(
expected_fields: impl StructFields<'ctx>,
llvm_ty: StructType<'ctx>,
ty_name: &'static str,
custom_verifiers: &[(&str, &FieldTypeVerifier<'ctx>)],
) -> Result<(), String> {
let expected_fields = expected_fields.to_vec();
if llvm_ty.count_fields() != u32::try_from(expected_fields.len()).unwrap() {
return Err(format!(
"Expected {} fields in `{ty_name}`, got {}",
expected_fields.len(),
llvm_ty.count_fields(),
));
}
expected_fields
.into_iter()
.enumerate()
.map(|(i, (field_name, expected_ty))| {
(field_name, expected_ty, llvm_ty.get_field_type_at_index(i as u32).unwrap())
})
.try_for_each(|(field_name, expected_ty, actual_ty)| {
if let Some((_, verifier)) =
custom_verifiers.iter().find(|verifier| verifier.0 == field_name)
{
verifier(actual_ty)
} else if expected_ty == actual_ty {
Ok(())
} else {
Err(format!("Expected {expected_ty} for `{ty_name}.{field_name}`, got {actual_ty}"))
}
})?;
Ok(())
}

View File

@ -0,0 +1,3 @@
pub use slice::*;
mod slice;

View File

@ -0,0 +1,254 @@
use inkwell::{
context::{AsContextRef, Context, ContextRef},
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::IntValue,
AddressSpace,
};
use itertools::Itertools;
use nac3core_derive::StructFields;
use crate::codegen::{
types::{
structure::{
check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields,
},
ProxyType,
},
values::{utils::SliceValue, ArraySliceValue, ProxyValue},
CodeGenContext, CodeGenerator,
};
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct SliceType<'ctx> {
ty: PointerType<'ctx>,
int_ty: IntType<'ctx>,
llvm_usize: IntType<'ctx>,
}
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct SliceFields<'ctx> {
#[value_type(bool_type())]
pub start_defined: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize)]
pub start: StructField<'ctx, IntValue<'ctx>>,
#[value_type(bool_type())]
pub stop_defined: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize)]
pub stop: StructField<'ctx, IntValue<'ctx>>,
#[value_type(bool_type())]
pub step_defined: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize)]
pub step: StructField<'ctx, IntValue<'ctx>>,
}
impl<'ctx> SliceFields<'ctx> {
/// Creates a new instance of [`SliceFields`] with a custom integer type for its range values.
#[must_use]
pub fn new_sized(ctx: &impl AsContextRef<'ctx>, int_ty: IntType<'ctx>) -> Self {
let ctx = unsafe { ContextRef::new(ctx.as_ctx_ref()) };
let mut counter = FieldIndexCounter::default();
SliceFields {
start_defined: StructField::create(&mut counter, "start_defined", ctx.bool_type()),
start: StructField::create(&mut counter, "start", int_ty),
stop_defined: StructField::create(&mut counter, "stop_defined", ctx.bool_type()),
stop: StructField::create(&mut counter, "stop", int_ty),
step_defined: StructField::create(&mut counter, "step_defined", ctx.bool_type()),
step: StructField::create(&mut counter, "step", int_ty),
}
}
}
impl<'ctx> SliceType<'ctx> {
/// Checks whether `llvm_ty` represents a `slice` type, returning [Err] if it does not.
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let ctx = llvm_ty.get_context();
let fields = SliceFields::new(ctx, llvm_usize);
let llvm_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else {
return Err(format!("Expected struct type for `Slice` type, got {llvm_ty}"));
};
check_struct_type_matches_fields(
fields,
llvm_ty,
"Slice",
&[
(fields.start.name(), &|ty| {
if ty.is_int_type() {
Ok(())
} else {
Err(format!("Expected int type for `Slice.start`, got {ty}"))
}
}),
(fields.stop.name(), &|ty| {
if ty.is_int_type() {
Ok(())
} else {
Err(format!("Expected int type for `Slice.stop`, got {ty}"))
}
}),
(fields.step.name(), &|ty| {
if ty.is_int_type() {
Ok(())
} else {
Err(format!("Expected int type for `Slice.step`, got {ty}"))
}
}),
],
)
}
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(&self) -> SliceFields<'ctx> {
SliceFields::new_sized(&self.int_ty.get_context(), self.int_ty)
}
/// Creates an LLVM type corresponding to the expected structure of a `Slice`.
#[must_use]
fn llvm_type(ctx: &'ctx Context, int_ty: IntType<'ctx>) -> PointerType<'ctx> {
let field_tys = SliceFields::new_sized(&int_ty.get_context(), int_ty)
.into_iter()
.map(|field| field.1)
.collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`SliceType`] with `int_ty` as its backing integer type.
#[must_use]
pub fn new(ctx: &'ctx Context, int_ty: IntType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
let llvm_ty = Self::llvm_type(ctx, int_ty);
Self { ty: llvm_ty, int_ty, llvm_usize }
}
/// Creates an instance of [`SliceType`] with `usize` as its backing integer type.
#[must_use]
pub fn new_usize<G: CodeGenerator + ?Sized>(generator: &G, ctx: &'ctx Context) -> Self {
let llvm_usize = generator.get_size_type(ctx);
Self::new(ctx, llvm_usize, llvm_usize)
}
/// Creates an [`SliceType`] from a [`PointerType`] representing a `slice`.
#[must_use]
pub fn from_type(
ptr_ty: PointerType<'ctx>,
int_ty: IntType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Self {
debug_assert!(Self::is_representable(ptr_ty, int_ty).is_ok());
Self { ty: ptr_ty, int_ty, llvm_usize }
}
#[must_use]
pub fn element_type(&self) -> IntType<'ctx> {
self.int_ty
}
/// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base type.
#[must_use]
pub fn alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
self.raw_alloca(generator, ctx, name),
self.int_ty,
self.llvm_usize,
name,
)
}
/// Converts an existing value into a [`ContiguousNDArrayValue`].
#[must_use]
pub fn map_value(
&self,
value: <<Self as ProxyType<'ctx>>::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
value,
self.int_ty,
self.llvm_usize,
name,
)
}
}
impl<'ctx> ProxyType<'ctx> for SliceType<'ctx> {
type Base = PointerType<'ctx>;
type Value = SliceValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn raw_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self::Value as ProxyValue<'ctx>>::Base {
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap()
}
fn array_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<SliceType<'ctx>> for PointerType<'ctx> {
fn from(value: SliceType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -0,0 +1,426 @@
use inkwell::{
types::AnyTypeEnum,
values::{BasicValueEnum, IntValue, PointerValue},
IntPredicate,
};
use crate::codegen::{CodeGenContext, CodeGenerator};
/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of
/// elements.
pub trait ArrayLikeValue<'ctx> {
/// Returns the element type of this array-like value.
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx>;
/// Returns the base pointer to the array.
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> PointerValue<'ctx>;
/// Returns the size of this array-like value.
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx>;
/// Returns a [`ArraySliceValue`] representing this value.
fn as_slice_value<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> ArraySliceValue<'ctx> {
ArraySliceValue::from_ptr_val(
self.base_ptr(ctx, generator),
self.size(ctx, generator),
None,
)
}
}
/// An array-like value that can be indexed by memory offset.
pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> {
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx>;
/// Returns the pointer to the data at the `idx`-th index.
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx>;
}
/// An array-like value that can have its array elements accessed as a [`BasicValueEnum`].
pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>:
ArrayLikeIndexer<'ctx, Index>
{
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn get_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) };
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
}
/// Returns the data at the `idx`-th index.
fn get<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> BasicValueEnum<'ctx> {
let ptr = self.ptr_offset(ctx, generator, idx, name);
ctx.builder.build_load(ptr, name.unwrap_or_default()).unwrap()
}
}
/// An array-like value that can have its array elements mutated as a [`BasicValueEnum`].
pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>:
ArrayLikeIndexer<'ctx, Index>
{
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn set_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
value: BasicValueEnum<'ctx>,
) {
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, None) };
ctx.builder.build_store(ptr, value).unwrap();
}
/// Sets the data at the `idx`-th index.
fn set<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
value: BasicValueEnum<'ctx>,
) {
let ptr = self.ptr_offset(ctx, generator, idx, None);
ctx.builder.build_store(ptr, value).unwrap();
}
}
/// An array-like value that can have its array elements accessed as an arbitrary type `T`.
pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>:
UntypedArrayLikeAccessor<'ctx, Index>
{
/// Casts an element from [`BasicValueEnum`] into `T`.
fn downcast_to_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> T;
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn get_typed_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> T {
let value = unsafe { self.get_unchecked(ctx, generator, idx, name) };
self.downcast_to_type(ctx, value)
}
/// Returns the data at the `idx`-th index.
fn get_typed<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> T {
let value = self.get(ctx, generator, idx, name);
self.downcast_to_type(ctx, value)
}
}
/// An array-like value that can have its array elements mutated as an arbitrary type `T`.
pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>:
UntypedArrayLikeMutator<'ctx, Index>
{
/// Casts an element from T into [`BasicValueEnum`].
fn upcast_from_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: T,
) -> BasicValueEnum<'ctx>;
/// # Safety
///
/// This function should be called with a valid index.
unsafe fn set_typed_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
value: T,
) {
let value = self.upcast_from_type(ctx, value);
unsafe { self.set_unchecked(ctx, generator, idx, value) }
}
/// Sets the data at the `idx`-th index.
fn set_typed<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
value: T,
) {
let value = self.upcast_from_type(ctx, value);
self.set(ctx, generator, idx, value);
}
}
/// Type alias for a function that casts a [`BasicValueEnum`] into a `T`.
type ValueDowncastFn<'ctx, T> =
Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> T + 'ctx>;
/// Type alias for a function that casts a `T` into a [`BasicValueEnum`].
type ValueUpcastFn<'ctx, T> = Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, T) -> BasicValueEnum<'ctx>>;
/// An adapter for constraining untyped array values as typed values.
pub struct TypedArrayLikeAdapter<'ctx, T, Adapted: ArrayLikeValue<'ctx> = ArraySliceValue<'ctx>> {
adapted: Adapted,
downcast_fn: ValueDowncastFn<'ctx, T>,
upcast_fn: ValueUpcastFn<'ctx, T>,
}
impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: ArrayLikeValue<'ctx>,
{
/// Creates a [`TypedArrayLikeAdapter`].
///
/// * `adapted` - The value to be adapted.
/// * `downcast_fn` - The function converting a [`BasicValueEnum`] into a `T`.
/// * `upcast_fn` - The function converting a T into a [`BasicValueEnum`].
pub fn from(
adapted: Adapted,
downcast_fn: ValueDowncastFn<'ctx, T>,
upcast_fn: ValueUpcastFn<'ctx, T>,
) -> Self {
TypedArrayLikeAdapter { adapted, downcast_fn, upcast_fn }
}
}
impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: ArrayLikeValue<'ctx>,
{
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.adapted.element_type(ctx, generator)
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> PointerValue<'ctx> {
self.adapted.base_ptr(ctx, generator)
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx> {
self.adapted.size(ctx, generator)
}
}
impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: ArrayLikeIndexer<'ctx, Index>,
{
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
unsafe { self.adapted.ptr_offset_unchecked(ctx, generator, idx, name) }
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
self.adapted.ptr_offset(ctx, generator, idx, name)
}
}
impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeAccessor<'ctx, Index>,
{
}
impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeMutator<'ctx, Index>,
{
}
impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeAccessor<'ctx, Index>,
{
fn downcast_to_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> T {
(self.downcast_fn)(ctx, value)
}
}
impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeMutator<'ctx, Index>,
{
fn upcast_from_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: T,
) -> BasicValueEnum<'ctx> {
(self.upcast_fn)(ctx, value)
}
}
/// An LLVM value representing an array slice, consisting of a pointer to the data and the size of
/// the slice.
#[derive(Copy, Clone)]
pub struct ArraySliceValue<'ctx>(PointerValue<'ctx>, IntValue<'ctx>, Option<&'ctx str>);
impl<'ctx> ArraySliceValue<'ctx> {
/// Creates an [`ArraySliceValue`] from a [`PointerValue`] and its size.
#[must_use]
pub fn from_ptr_val(
ptr: PointerValue<'ctx>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Self {
ArraySliceValue(ptr, size, name)
}
}
impl<'ctx> From<ArraySliceValue<'ctx>> for PointerValue<'ctx> {
fn from(value: ArraySliceValue<'ctx>) -> Self {
value.0
}
}
impl<'ctx> ArrayLikeValue<'ctx> for ArraySliceValue<'ctx> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> AnyTypeEnum<'ctx> {
self.0.get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
self.0
}
fn size<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.1
}
}
impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"list index out of range",
[None, None, None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ArraySliceValue<'ctx> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ArraySliceValue<'ctx> {}

View File

@ -0,0 +1,241 @@
use inkwell::{
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use super::{
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
};
use crate::codegen::{
types::ListType,
{CodeGenContext, CodeGenerator},
};
/// Proxy type for accessing a `list` value in LLVM.
#[derive(Copy, Clone)]
pub struct ListValue<'ctx> {
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> ListValue<'ctx> {
/// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an
/// instance.
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
ListType::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`ListValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
ListValue { value: ptr, llvm_usize, name }
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
)
.unwrap()
}
}
/// Returns the pointer to the field storing the size of this `list`.
fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the array of data elements `data` into this instance.
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
ctx.builder.build_store(self.pptr_to_data(ctx), data).unwrap();
}
/// Convenience method for creating a new array storing data elements with the given element
/// type `elem_ty` and `size`.
///
/// If `size` is [None], the size stored in the field of this instance is used instead.
pub fn create_data(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: BasicTypeEnum<'ctx>,
size: Option<IntValue<'ctx>>,
) {
let size = size.unwrap_or_else(|| self.load_size(ctx, None));
let data = ctx
.builder
.build_select(
ctx.builder
.build_int_compare(IntPredicate::NE, size, self.llvm_usize.const_zero(), "")
.unwrap(),
ctx.builder.build_array_alloca(elem_ty, size, "").unwrap(),
elem_ty.ptr_type(AddressSpace::default()).const_zero(),
"",
)
.map(BasicValueEnum::into_pointer_value)
.unwrap();
self.store_data(ctx, data);
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
#[must_use]
pub fn data(&self) -> ListDataProxy<'ctx, '_> {
ListDataProxy(self)
}
/// Stores the `size` of this `list` into this instance.
pub fn store_size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
size: IntValue<'ctx>,
) {
debug_assert_eq!(size.get_type(), generator.get_size_type(ctx.ctx));
let psize = self.ptr_to_size(ctx);
ctx.builder.build_store(psize, size).unwrap();
}
/// Returns the size of this `list` as a value.
pub fn load_size(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let psize = self.ptr_to_size(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.name.map(|v| format!("{v}.size")))
.unwrap_or_default();
ctx.builder
.build_load(psize, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
}
impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = ListType<'ctx>;
fn get_type(&self) -> Self::Type {
ListType::from_type(self.as_base_value().get_type(), self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<ListValue<'ctx>> for PointerValue<'ctx> {
fn from(value: ListValue<'ctx>) -> Self {
value.as_base_value()
}
}
/// Proxy type for accessing the `data` array of an `list` instance in LLVM.
#[derive(Copy, Clone)]
pub struct ListDataProxy<'ctx, 'a>(&'a ListValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> AnyTypeEnum<'ctx> {
self.0.value.get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder
.build_load(self.0.pptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_size(ctx, None)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
debug_assert_eq!(idx.get_type(), generator.get_size_type(ctx.ctx));
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"list index out of range",
[None, None, None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ListDataProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ListDataProxy<'ctx, '_> {}

View File

@ -0,0 +1,47 @@
use inkwell::{context::Context, values::BasicValue};
use super::types::ProxyType;
use crate::codegen::CodeGenerator;
pub use array::*;
pub use list::*;
pub use range::*;
mod array;
mod list;
pub mod ndarray;
mod range;
pub mod utils;
/// A LLVM type that is used to represent a non-primitive value in NAC3.
pub trait ProxyValue<'ctx>: Into<Self::Base> {
/// The type of LLVM values represented by this instance. This is usually the
/// [LLVM pointer type][PointerValue].
type Base: BasicValue<'ctx>;
/// The type of this value.
type Type: ProxyType<'ctx, Value = Self>;
/// Checks whether `value` can be represented by this [`ProxyValue`].
fn is_instance<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
value: impl BasicValue<'ctx>,
) -> Result<(), String> {
Self::Type::is_type(generator, ctx, value.as_basic_value_enum().get_type())
}
/// Checks whether `value` can be represented by this [`ProxyValue`].
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
value: Self::Base,
) -> Result<(), String> {
Self::is_instance(generator, ctx, value.as_basic_value_enum())
}
/// Returns the [type][ProxyType] of this value.
fn get_type(&self) -> Self::Type;
/// Returns the [base value][Self::Base] of this proxy.
fn as_base_value(&self) -> Self::Base;
}

View File

@ -0,0 +1,202 @@
use inkwell::{
types::{BasicType, BasicTypeEnum, IntType},
values::{IntValue, PointerValue},
AddressSpace,
};
use super::{ArrayLikeValue, NDArrayValue, ProxyValue};
use crate::codegen::{
stmt::gen_if_callback,
types::{
ndarray::{ContiguousNDArrayType, NDArrayType},
structure::StructField,
},
CodeGenContext, CodeGenerator,
};
#[derive(Copy, Clone)]
pub struct ContiguousNDArrayValue<'ctx> {
value: PointerValue<'ctx>,
item: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> ContiguousNDArrayValue<'ctx> {
/// Checks whether `value` is an instance of `ContiguousNDArray`, returning [Err] if `value` is
/// not an instance.
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
<Self as ProxyValue<'ctx>>::Type::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`ContiguousNDArrayValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
Self { value: ptr, item: dtype, llvm_usize, name }
}
fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields().ndims
}
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
self.ndims_field().set(ctx, self.as_base_value(), value, self.name);
}
fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields().shape
}
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
self.shape_field().set(ctx, self.as_base_value(), value, self.name);
}
pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.shape_field().get(ctx, self.value, self.name)
}
fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields().data
}
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
self.data_field().set(ctx, self.as_base_value(), value, self.name);
}
pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.data_field().get(ctx, self.value, self.name)
}
}
impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = ContiguousNDArrayType<'ctx>;
fn get_type(&self) -> Self::Type {
<Self as ProxyValue<'ctx>>::Type::from_type(
self.as_base_value().get_type(),
self.item,
self.llvm_usize,
)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<ContiguousNDArrayValue<'ctx>> for PointerValue<'ctx> {
fn from(value: ContiguousNDArrayValue<'ctx>) -> Self {
value.as_base_value()
}
}
impl<'ctx> NDArrayValue<'ctx> {
/// Create a [`ContiguousNDArrayValue`] from the contents of this ndarray.
///
/// This function may or may not be expensive depending on if this ndarray has contiguous data.
///
/// If this ndarray is not C-contiguous, this function will allocate memory on the stack for the
/// `data` field of the returned [`ContiguousNDArrayValue`] and copy contents of this ndarray to
/// there.
///
/// If this ndarray is C-contiguous, contents of this ndarray will not be copied. The created
/// [`ContiguousNDArrayValue`] will share memory with this ndarray.
pub fn make_contiguous_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ContiguousNDArrayValue<'ctx> {
let result = ContiguousNDArrayType::new(generator, ctx.ctx, self.dtype)
.alloca(generator, ctx, self.name);
// Set ndims and shape.
let ndims = self
.ndims
.map_or_else(|| self.load_ndims(ctx), |ndims| self.llvm_usize.const_int(ndims, false));
result.store_ndims(ctx, ndims);
let shape = self.shape();
result.store_shape(ctx, shape.base_ptr(ctx, generator));
gen_if_callback(
generator,
ctx,
|generator, ctx| Ok(self.is_c_contiguous(generator, ctx)),
|_, ctx| {
// This ndarray is contiguous.
let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name);
let data = ctx
.builder
.build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "")
.unwrap();
result.store_data(ctx, data);
Ok(())
},
|generator, ctx| {
// This ndarray is not contiguous. Do a full-copy on `data`. `make_copy` produces an
// ndarray with contiguous `data`.
let copied_ndarray = self.make_copy(generator, ctx);
let data = copied_ndarray.data().base_ptr(ctx, generator);
let data = ctx
.builder
.build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "")
.unwrap();
result.store_data(ctx, data);
Ok(())
},
)
.unwrap();
result
}
/// Create an [`NDArrayValue`] from a [`ContiguousNDArrayValue`].
///
/// The operation is cheap. The newly created [`NDArrayValue`] will share the same memory as the
/// [`ContiguousNDArrayValue`].
///
/// `ndims` has to be provided as [`NDArrayValue`] requires a statically known `ndims` value,
/// despite the fact that the information should be contained within the
/// [`ContiguousNDArrayValue`].
pub fn from_contiguous_ndarray<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
carray: ContiguousNDArrayValue<'ctx>,
ndims: u64,
) -> Self {
// TODO: Debug assert `ndims == carray.ndims` to catch bugs.
// Allocate the resulting ndarray.
let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, Some(ndims))
.construct_uninitialized(generator, ctx, carray.name);
// Copy shape and update strides
let shape = carray.load_shape(ctx);
ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.set_strides_contiguous(generator, ctx);
// Share data
let data = carray.load_data(ctx);
ndarray.store_data(
ctx,
ctx.builder
.build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
.unwrap(),
);
ndarray
}
}

View File

@ -0,0 +1,262 @@
use inkwell::{
types::IntType,
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use nac3parser::ast::{Expr, ExprKind};
use crate::{
codegen::{
irrt,
types::{
ndarray::{NDArrayType, NDIndexType},
structure::StructField,
utils::SliceType,
},
values::{ndarray::NDArrayValue, utils::RustSlice, ProxyValue},
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type,
};
/// An IRRT representation of an ndarray subscript index.
#[derive(Copy, Clone)]
pub struct NDIndexValue<'ctx> {
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> NDIndexValue<'ctx> {
/// Checks whether `value` is an instance of `ndindex`, returning [Err] if `value` is not an
/// instance.
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
<Self as ProxyValue<'ctx>>::Type::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`NDIndexValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
Self { value: ptr, llvm_usize, name }
}
fn type_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields().type_
}
pub fn load_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.type_field().get(ctx, self.value, self.name)
}
pub fn store_type(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
self.type_field().set(ctx, self.value, value, self.name);
}
fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields().data
}
pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.data_field().get(ctx, self.value, self.name)
}
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
self.data_field().set(ctx, self.value, value, self.name);
}
}
impl<'ctx> ProxyValue<'ctx> for NDIndexValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = NDIndexType<'ctx>;
fn get_type(&self) -> Self::Type {
Self::Type::from_type(self.value.get_type(), self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<NDIndexValue<'ctx>> for PointerValue<'ctx> {
fn from(value: NDIndexValue<'ctx>) -> Self {
value.as_base_value()
}
}
impl<'ctx> NDArrayValue<'ctx> {
/// Get the expected `ndims` after indexing with `indices`.
#[must_use]
fn deduce_ndims_after_indexing_with(&self, indices: &[RustNDIndex<'ctx>]) -> Option<u64> {
let mut ndims = self.ndims?;
for index in indices {
match index {
RustNDIndex::SingleElement(_) => {
ndims -= 1; // Single elements decrements ndims
}
RustNDIndex::NewAxis => {
ndims += 1; // `np.newaxis` / `none` adds a new axis
}
RustNDIndex::Ellipsis | RustNDIndex::Slice(_) => {}
}
}
Some(ndims)
}
/// Index into the ndarray, and return a newly-allocated view on this ndarray.
///
/// This function behaves like NumPy's ndarray indexing, but if the indices index
/// into a single element, an unsized ndarray is returned.
#[must_use]
pub fn index<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
indices: &[RustNDIndex<'ctx>],
) -> Self {
assert!(self.ndims.is_some(), "NDArrayValue::index is only supported for instances with compile-time known ndims (self.ndims = Some(...))");
let dst_ndims = self.deduce_ndims_after_indexing_with(indices);
let dst_ndarray = NDArrayType::new(generator, ctx.ctx, self.dtype, dst_ndims)
.construct_uninitialized(generator, ctx, None);
let indices =
NDIndexType::new(generator, ctx.ctx).construct_ndindices(generator, ctx, indices);
irrt::ndarray::call_nac3_ndarray_index(generator, ctx, indices, *self, dst_ndarray);
dst_ndarray
}
}
/// A convenience enum representing a [`NDIndexValue`].
// TODO: Rename to CTConstNDIndex
#[derive(Debug, Clone)]
pub enum RustNDIndex<'ctx> {
SingleElement(IntValue<'ctx>),
Slice(RustSlice<'ctx>),
NewAxis,
Ellipsis,
}
impl<'ctx> RustNDIndex<'ctx> {
/// Generate LLVM code to transform an ndarray subscript expression to
/// its list of [`RustNDIndex`]
///
/// i.e.,
/// ```python
/// my_ndarray[::3, 1, :2:]
/// ^^^^^^^^^^^ Then these into a three `RustNDIndex`es
/// ```
pub fn from_subscript_expr<G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
subscript: &Expr<Option<Type>>,
) -> Result<Vec<RustNDIndex<'ctx>>, String> {
// Annoying notes about `slice`
// - `my_array[5]`
// - slice is a `Constant`
// - `my_array[:5]`
// - slice is a `Slice`
// - `my_array[:]`
// - slice is a `Slice`, but lower upper step would all be `Option::None`
// - `my_array[:, :]`
// - slice is now a `Tuple` of two `Slice`-s
//
// In summary:
// - when there is a comma "," within [], `slice` will be a `Tuple` of the entries.
// - when there is not comma "," within [] (i.e., just a single entry), `slice` will be that entry itself.
//
// So we first "flatten" out the slice expression
let index_exprs = match &subscript.node {
ExprKind::Tuple { elts, .. } => elts.iter().collect_vec(),
_ => vec![subscript],
};
// Process all index expressions
let mut rust_ndindices: Vec<RustNDIndex> = Vec::with_capacity(index_exprs.len()); // Not using iterators here because `?` is used here.
for index_expr in index_exprs {
// NOTE: Currently nac3core's slices do not have an object representation,
// so the code/implementation looks awkward - we have to do pattern matching on the expression
let ndindex = if let ExprKind::Slice { lower, upper, step } = &index_expr.node {
// Handle slices
let slice = RustSlice::from_slice_expr(generator, ctx, lower, upper, step)?;
RustNDIndex::Slice(slice)
} else {
// Treat and handle everything else as a single element index.
let index = generator.gen_expr(ctx, index_expr)?.unwrap().to_basic_value_enum(
ctx,
generator,
ctx.primitives.int32, // Must be int32, this checks for illegal values
)?;
let index = index.into_int_value();
RustNDIndex::SingleElement(index)
};
rust_ndindices.push(ndindex);
}
Ok(rust_ndindices)
}
/// Get the value to set `NDIndex::type` for this variant.
#[must_use]
pub fn get_type_id(&self) -> u64 {
// Defined in IRRT, must be in sync
match self {
RustNDIndex::SingleElement(_) => 0,
RustNDIndex::Slice(_) => 1,
RustNDIndex::NewAxis => 2,
RustNDIndex::Ellipsis => 3,
}
}
/// Serialize this [`RustNDIndex`] by writing it into an LLVM [`NDIndexValue`].
pub fn write_to_ndindex<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dst_ndindex: NDIndexValue<'ctx>,
) {
let llvm_pi8 = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
// Set `dst_ndindex.type`
dst_ndindex.store_type(ctx, ctx.ctx.i8_type().const_int(self.get_type_id(), false));
// Set `dst_ndindex_ptr->data`
match self {
RustNDIndex::SingleElement(in_index) => {
let index_ptr = ctx.builder.build_alloca(ctx.ctx.i32_type(), "").unwrap();
ctx.builder.build_store(index_ptr, *in_index).unwrap();
dst_ndindex.store_data(
ctx,
ctx.builder.build_pointer_cast(index_ptr, llvm_pi8, "").unwrap(),
);
}
RustNDIndex::Slice(in_rust_slice) => {
let user_slice_ptr =
SliceType::new(ctx.ctx, ctx.ctx.i32_type(), generator.get_size_type(ctx.ctx))
.alloca(generator, ctx, None);
in_rust_slice.write_to_slice(ctx, user_slice_ptr);
dst_ndindex.store_data(
ctx,
ctx.builder.build_pointer_cast(user_slice_ptr.into(), llvm_pi8, "").unwrap(),
);
}
RustNDIndex::NewAxis | RustNDIndex::Ellipsis => {}
}
}
}

View File

@ -0,0 +1,933 @@
use inkwell::{
types::{AnyType, AnyTypeEnum, BasicType, BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
use super::{
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
};
use crate::codegen::{
irrt,
llvm_intrinsics::{call_int_umin, call_memcpy_generic_array},
stmt::gen_for_callback_incrementing,
type_aligned_alloca,
types::{ndarray::NDArrayType, structure::StructField},
CodeGenContext, CodeGenerator,
};
pub use contiguous::*;
pub use indexing::*;
pub use nditer::*;
pub use view::*;
mod contiguous;
mod indexing;
mod nditer;
mod view;
/// Proxy type for accessing an `NDArray` value in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayValue<'ctx> {
value: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> NDArrayValue<'ctx> {
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an
/// instance.
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
NDArrayType::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`NDArrayValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>,
ndims: Option<u64>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
NDArrayValue { value: ptr, dtype, ndims, llvm_usize, name }
}
fn ndims_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).ndims
}
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.ndims_field(ctx).ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the number of dimensions `ndims` into this instance.
pub fn store_ndims<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
ndims: IntValue<'ctx>,
) {
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_store(pndims, ndims).unwrap();
}
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
fn itemsize_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).itemsize
}
/// Stores the size of each element `itemsize` into this instance.
pub fn store_itemsize<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
itemsize: IntValue<'ctx>,
) {
debug_assert_eq!(itemsize.get_type(), generator.get_size_type(ctx.ctx));
self.itemsize_field(ctx).set(ctx, self.value, itemsize, self.name);
}
/// Returns the size of each element of this `NDArray` as a value.
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.itemsize_field(ctx).get(ctx, self.value, self.name)
}
fn shape_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).shape
}
/// Returns the double-indirection pointer to the `shape` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.shape_field(ctx).ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of dimension sizes `dims` into this instance.
fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
self.shape_field(ctx).set(ctx, self.as_base_value(), dims, self.name);
}
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
pub fn create_shape(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_shape(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
#[must_use]
pub fn shape(&self) -> NDArrayShapeProxy<'ctx, '_> {
NDArrayShapeProxy(self)
}
fn strides_field(
&self,
ctx: &CodeGenContext<'ctx, '_>,
) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).strides
}
/// Returns the double-indirection pointer to the `strides` array, as if by calling
/// `getelementptr` on the field.
fn ptr_to_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.strides_field(ctx).ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of stride sizes `strides` into this instance.
fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, strides: PointerValue<'ctx>) {
self.strides_field(ctx).set(ctx, self.as_base_value(), strides, self.name);
}
/// Convenience method for creating a new array storing the stride with the given `size`.
pub fn create_strides(
&self,
ctx: &CodeGenContext<'ctx, '_>,
llvm_usize: IntType<'ctx>,
size: IntValue<'ctx>,
) {
self.store_strides(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap());
}
/// Returns a proxy object to the field storing the stride of each dimension of this `NDArray`.
#[must_use]
pub fn strides(&self) -> NDArrayStridesProxy<'ctx, '_> {
NDArrayStridesProxy(self)
}
fn data_field(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).data
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.data_field(ctx).ptr_by_gep(ctx, self.value, self.name)
}
/// Stores the array of data elements `data` into this instance.
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
let data = ctx
.builder
.build_bit_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
.unwrap();
self.data_field(ctx).set(ctx, self.as_base_value(), data.into_pointer_value(), self.name);
}
/// Convenience method for creating a new array storing data elements with the given element
/// type `elem_ty` and `size`.
///
/// The data buffer will be allocated on the stack, and is considered to be owned by this ndarray instance.
///
/// # Safety
///
/// The caller must ensure that `shape` and `itemsize` of this ndarray instance is initialized.
pub unsafe fn create_data<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
let nbytes = self.nbytes(generator, ctx);
let data = type_aligned_alloca(generator, ctx, self.dtype, nbytes, None);
self.store_data(ctx, data);
self.set_strides_contiguous(generator, ctx);
}
/// Returns a proxy object to the field storing the data of this `NDArray`.
#[must_use]
pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> {
NDArrayDataProxy(self)
}
/// Copy shape dimensions from an array.
pub fn copy_shape_from_array<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
shape: PointerValue<'ctx>,
) {
let num_items = self.load_ndims(ctx);
call_memcpy_generic_array(
ctx,
self.shape().base_ptr(ctx, generator),
shape,
num_items,
ctx.ctx.bool_type().const_zero(),
);
}
/// Copy shape dimensions from an ndarray.
/// Panics if `ndims` mismatches.
pub fn copy_shape_from_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayValue<'ctx>,
) {
if self.ndims.is_some() && src_ndarray.ndims.is_some() {
assert_eq!(self.ndims, src_ndarray.ndims);
} else {
let self_ndims = self.load_ndims(ctx);
let src_ndims = src_ndarray.load_ndims(ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(
IntPredicate::EQ,
self_ndims,
src_ndims,
""
).unwrap(),
"0:AssertionError",
"NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})",
[Some(self_ndims), Some(src_ndims), None],
ctx.current_loc
);
}
let src_shape = src_ndarray.shape().base_ptr(ctx, generator);
self.copy_shape_from_array(generator, ctx, src_shape);
}
/// Copy strides dimensions from an array.
pub fn copy_strides_from_array<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
strides: PointerValue<'ctx>,
) {
let num_items = self.load_ndims(ctx);
call_memcpy_generic_array(
ctx,
self.strides().base_ptr(ctx, generator),
strides,
num_items,
ctx.ctx.bool_type().const_zero(),
);
}
/// Copy strides dimensions from an ndarray.
/// Panics if `ndims` mismatches.
pub fn copy_strides_from_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src_ndarray: NDArrayValue<'ctx>,
) {
if self.ndims.is_some() && src_ndarray.ndims.is_some() {
assert_eq!(self.ndims, src_ndarray.ndims);
} else {
let self_ndims = self.load_ndims(ctx);
let src_ndims = src_ndarray.load_ndims(ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(
IntPredicate::EQ,
self_ndims,
src_ndims,
""
).unwrap(),
"0:AssertionError",
"NDArrayValue::copy_shape_from_ndarray: Expected self.ndims ({0}) == src_ndarray.ndims ({1})",
[Some(self_ndims), Some(src_ndims), None],
ctx.current_loc
);
}
let src_strides = src_ndarray.strides().base_ptr(ctx, generator);
self.copy_strides_from_array(generator, ctx, src_strides);
}
/// Get the `np.size()` of this ndarray.
pub fn size<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_size(generator, ctx, *self)
}
/// Get the `ndarray.nbytes` of this ndarray.
pub fn nbytes<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_nbytes(generator, ctx, *self)
}
/// Get the `len()` of this ndarray.
pub fn len<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self)
}
/// Check if this ndarray is C-contiguous.
///
/// See NumPy's `flags["C_CONTIGUOUS"]`: <https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html#numpy.ndarray.flags>
pub fn is_c_contiguous<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_is_c_contiguous(generator, ctx, *self)
}
/// Call [`call_nac3_ndarray_set_strides_by_shape`] on this ndarray to update `strides`.
///
/// Update the ndarray's strides to make the ndarray contiguous.
pub fn set_strides_contiguous<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) {
irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self);
}
#[must_use]
pub fn make_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
let clone = if self.ndims.is_some() {
self.get_type().construct_uninitialized(generator, ctx, None)
} else {
self.get_type().construct_dyn_ndims(generator, ctx, self.load_ndims(ctx), None)
};
let shape = self.shape();
clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
unsafe { clone.create_data(generator, ctx) };
clone.copy_data_from(generator, ctx, *self);
clone
}
/// Copy data from another ndarray.
///
/// This ndarray and `src` is that their `np.size()` should be the same. Their shapes
/// do not matter. The copying order is determined by how their flattened views look.
///
/// Panics if the `dtype`s of ndarrays are different.
pub fn copy_data_from<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
src: NDArrayValue<'ctx>,
) {
assert_eq!(self.dtype, src.dtype, "self and src dtype should match");
irrt::ndarray::call_nac3_ndarray_copy_data(generator, ctx, src, *self);
}
/// Returns true if this ndarray is unsized - `ndims == 0` and only contains a scalar.
#[must_use]
pub fn is_unsized(&self) -> Option<bool> {
self.ndims.map(|ndims| ndims == 0)
}
/// If this ndarray is unsized, return its sole value as an [`AnyObject`].
/// Otherwise, do nothing and return the ndarray itself.
// TODO: Rename to get_unsized_element
pub fn split_unsized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ScalarOrNDArray<'ctx> {
let Some(is_unsized) = self.is_unsized() else { todo!() };
if is_unsized {
// NOTE: `np.size(self) == 0` here is never possible.
let zero = generator.get_size_type(ctx.ctx).const_zero();
let value = unsafe { self.data().get_unchecked(ctx, generator, &zero, None) };
ScalarOrNDArray::Scalar(value)
} else {
ScalarOrNDArray::NDArray(*self)
}
}
}
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = NDArrayType<'ctx>;
fn get_type(&self) -> Self::Type {
NDArrayType::from_type(
self.as_base_value().get_type(),
self.dtype,
self.ndims,
self.llvm_usize,
)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
fn from(value: NDArrayValue<'ctx>) -> Self {
value.as_base_value()
}
}
/// Proxy type for accessing the `shape` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayShapeProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayShapeProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.shape().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
self.0.shape_field(ctx).get(ctx, self.0.as_base_value(), self.0.name)
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_ndims(ctx)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn downcast_to_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
value.into_int_value()
}
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayShapeProxy<'ctx, '_> {
fn upcast_from_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> BasicValueEnum<'ctx> {
value.into()
}
}
/// Proxy type for accessing the `strides` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayStridesProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayStridesProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> AnyTypeEnum<'ctx> {
self.0.strides().base_ptr(ctx, generator).get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
self.0.strides_field(ctx).get(ctx, self.0.as_base_value(), self.0.name)
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> IntValue<'ctx> {
self.0.load_ndims(ctx)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {}
impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
fn downcast_to_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> IntValue<'ctx> {
value.into_int_value()
}
}
impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayStridesProxy<'ctx, '_> {
fn upcast_from_type(
&self,
_: &mut CodeGenContext<'ctx, '_>,
value: IntValue<'ctx>,
) -> BasicValueEnum<'ctx> {
value.into()
}
}
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>);
impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
fn element_type<G: CodeGenerator + ?Sized>(
&self,
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> AnyTypeEnum<'ctx> {
self.0.dtype.as_any_type_enum()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
self.0.data_field(ctx).get(ctx, self.0.as_base_value(), self.0.name)
}
fn size<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
generator: &G,
) -> IntValue<'ctx> {
irrt::ndarray::call_ndarray_calc_size(
generator,
ctx,
&self.as_slice_value(ctx, generator),
(None, None),
)
}
}
impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let sizeof_elem = ctx
.builder
.build_int_truncate_or_bit_cast(
self.element_type(ctx, generator).size_of().unwrap(),
idx.get_type(),
"",
)
.unwrap();
let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap();
let ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[idx],
name.unwrap_or_default(),
)
.unwrap()
};
// Current implementation is transparent - The returned pointer type is
// already cast into the expected type, allowing for immediately
// load/store.
ctx.builder
.build_pointer_cast(
ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap()
.ptr_type(AddressSpace::default()),
"",
)
.unwrap()
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let data_sz = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, data_sz, "").unwrap();
ctx.make_assert(
generator,
in_range,
"0:IndexError",
"index {0} is out of bounds with size {1}",
[Some(*idx), Some(self.0.load_ndims(ctx)), None],
ctx.current_loc,
);
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) };
// Current implementation is transparent - The returned pointer type is
// already cast into the expected type, allowing for immediately
// load/store.
ctx.builder
.build_pointer_cast(
ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap()
.ptr_type(AddressSpace::default()),
"",
)
.unwrap()
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
indices: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_elem_ty = indices
.ptr_offset(ctx, generator, &llvm_usize.const_zero(), None)
.get_type()
.get_element_type();
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected list[int32] but got {indices_elem_ty}")
};
assert_eq!(
indices_elem_ty.get_bit_width(),
32,
"Expected list[int32] but got list[int{}]",
indices_elem_ty.get_bit_width()
);
let index = irrt::ndarray::call_ndarray_flatten_index(generator, ctx, *self.0, indices);
let sizeof_elem = ctx
.builder
.build_int_truncate_or_bit_cast(
self.element_type(ctx, generator).size_of().unwrap(),
index.get_type(),
"",
)
.unwrap();
let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap();
let ptr = unsafe {
ctx.builder
.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[index],
name.unwrap_or_default(),
)
.unwrap()
};
// TODO: Current implementation is transparent
ctx.builder
.build_pointer_cast(
ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap()
.ptr_type(AddressSpace::default()),
"",
)
.unwrap()
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G,
indices: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_size = indices.size(ctx, generator);
let nidx_leq_ndims = ctx
.builder
.build_int_compare(IntPredicate::SLE, indices_size, self.0.load_ndims(ctx), "")
.unwrap();
ctx.make_assert(
generator,
nidx_leq_ndims,
"0:IndexError",
"invalid index to scalar variable",
[None, None, None],
ctx.current_loc,
);
let indices_len = indices.size(ctx, generator);
let ndarray_len = self.0.load_ndims(ctx);
let len = call_int_umin(ctx, indices_len, ndarray_len, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(len, false),
|generator, ctx, _, i| {
let (dim_idx, dim_sz) = unsafe {
(
indices.get_unchecked(ctx, generator, &i, None).into_int_value(),
self.0.shape().get_typed_unchecked(ctx, generator, &i, None),
)
};
let dim_idx = ctx
.builder
.build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "")
.unwrap();
let dim_lt =
ctx.builder.build_int_compare(IntPredicate::SLT, dim_idx, dim_sz, "").unwrap();
ctx.make_assert(
generator,
dim_lt,
"0:IndexError",
"index {0} is out of bounds for axis 0 with size {1}",
[Some(dim_idx), Some(dim_sz), None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let ptr = unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) };
// TODO: Current implementation is transparent
ctx.builder
.build_pointer_cast(
ptr,
BasicTypeEnum::try_from(self.element_type(ctx, generator))
.unwrap()
.ptr_type(AddressSpace::default()),
"",
)
.unwrap()
}
}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
}
/// A version of [`call_nac3_ndarray_set_strides_by_shape`] in Rust.
///
/// This function is used generating strides for globally defined contiguous ndarrays.
#[must_use]
pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec<u64> {
let mut strides = Vec::with_capacity(ndims as usize);
let mut stride_product = 1u64;
for i in 0..ndims {
let axis = ndims - i - 1;
strides[axis as usize] = stride_product * itemsize;
stride_product *= shape[axis as usize];
}
strides
}
/// A convenience enum for implementing functions that acts on scalars or ndarrays or both.
#[derive(Clone, Copy)]
pub enum ScalarOrNDArray<'ctx> {
Scalar(BasicValueEnum<'ctx>),
NDArray(NDArrayValue<'ctx>),
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// Get the underlying [`BasicValueEnum<'ctx>`] of this [`ScalarOrNDArray`].
#[must_use]
pub fn to_basic_value_enum(self) -> BasicValueEnum<'ctx> {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar,
ScalarOrNDArray::NDArray(ndarray) => ndarray.as_base_value().into(),
}
}
}

View File

@ -0,0 +1,176 @@
use inkwell::{
types::{BasicType, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace,
};
use super::{NDArrayValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator};
use crate::codegen::{
irrt,
stmt::{gen_for_callback, BreakContinueHooks},
types::{ndarray::NDIterType, structure::StructField},
values::{ArraySliceValue, TypedArrayLikeAdapter},
CodeGenContext, CodeGenerator,
};
#[derive(Copy, Clone)]
pub struct NDIterValue<'ctx> {
value: PointerValue<'ctx>,
parent: NDArrayValue<'ctx>,
indices: ArraySliceValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> NDIterValue<'ctx> {
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an
/// instance.
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
<Self as ProxyValue>::Type::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`NDArrayValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
parent: NDArrayValue<'ctx>,
indices: ArraySliceValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
Self { value: ptr, parent, indices, llvm_usize, name }
}
/// Is the current iteration valid?
///
/// If true, then `element`, `indices` and `nth` contain details about the current element.
///
/// If `ndarray` is unsized, this returns true only for the first iteration.
/// If `ndarray` is 0-sized, this always returns false.
#[must_use]
pub fn has_element<G: CodeGenerator + ?Sized>(
&self,
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
) -> IntValue<'ctx> {
irrt::ndarray::call_nac3_nditer_has_element(generator, ctx, *self)
}
/// Go to the next element. If `has_element()` is false, then this has undefined behavior.
///
/// If `ndarray` is unsized, this can only be called once.
/// If `ndarray` is 0-sized, this can never be called.
pub fn next<G: CodeGenerator + ?Sized>(&self, generator: &G, ctx: &CodeGenContext<'ctx, '_>) {
irrt::ndarray::call_nac3_nditer_next(generator, ctx, *self);
}
fn element(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).element
}
/// Get pointer to the current element.
#[must_use]
pub fn get_pointer(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let elem_ty = self.parent.dtype;
let p = self.element(ctx).get(ctx, self.as_base_value(), None);
ctx.builder
.build_pointer_cast(p, elem_ty.ptr_type(AddressSpace::default()), "element")
.unwrap()
}
/// Get the value of the current element.
#[must_use]
pub fn get_scalar(&self, ctx: &CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
let p = self.get_pointer(ctx);
ctx.builder.build_load(p, "value").unwrap()
}
fn nth(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields(ctx.ctx).nth
}
/// Get the index of the current element if this ndarray were a flat ndarray.
#[must_use]
pub fn get_index(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.nth(ctx).get(ctx, self.as_base_value(), None)
}
/// Get the indices of the current element.
#[must_use]
pub fn get_indices(
&'ctx self,
) -> impl TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> + TypedArrayLikeMutator<'ctx, IntValue<'ctx>>
{
TypedArrayLikeAdapter::from(
self.indices,
Box::new(|ctx, val| {
ctx.builder
.build_int_z_extend_or_bit_cast(val.into_int_value(), self.llvm_usize, "")
.unwrap()
}),
Box::new(|_, val| val.into()),
)
}
}
impl<'ctx> ProxyValue<'ctx> for NDIterValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = NDIterType<'ctx>;
fn get_type(&self) -> Self::Type {
NDIterType::from_type(self.as_base_value().get_type(), self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<NDIterValue<'ctx>> for PointerValue<'ctx> {
fn from(value: NDIterValue<'ctx>) -> Self {
value.as_base_value()
}
}
impl<'ctx> NDArrayValue<'ctx> {
/// Iterate through every element in the ndarray.
///
/// `body` has access to [`BreakContinueHooks`] to short-circuit and [`NDIterValue`] to
/// get properties of the current iteration (e.g., the current element, indices, etc.)
pub fn foreach<'a, G, F>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
body: F,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
F: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
BreakContinueHooks<'ctx>,
NDIterValue<'ctx>,
) -> Result<(), String>,
{
gen_for_callback(
generator,
ctx,
Some("ndarray_foreach"),
|generator, ctx| {
Ok(NDIterType::new(generator, ctx.ctx).construct(generator, ctx, *self))
},
|generator, ctx, nditer| Ok(nditer.has_element(generator, ctx)),
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|generator, ctx, nditer| {
nditer.next(generator, ctx);
Ok(())
},
)
}
}

View File

@ -0,0 +1,36 @@
use std::iter::{once, repeat_n};
use itertools::Itertools;
use crate::codegen::{
values::ndarray::{NDArrayValue, RustNDIndex},
CodeGenContext, CodeGenerator,
};
impl<'ctx> NDArrayValue<'ctx> {
/// Make sure the ndarray is at least `ndmin`-dimensional.
///
/// If this ndarray's `ndims` is less than `ndmin`, a view is created on this with 1s prepended
/// to the shape. Otherwise, this function does nothing and return this ndarray.
#[must_use]
pub fn atleast_nd<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndmin: u64,
) -> Self {
assert!(self.ndims.is_some(), "NDArrayValue::atleast_nd is only supported for instances with compile-time known ndims (self.ndims = Some(...))");
let ndims = self.ndims.unwrap();
if ndims < ndmin {
// Extend the dimensions with np.newaxis.
let indices = repeat_n(RustNDIndex::NewAxis, (ndmin - ndims) as usize)
.chain(once(RustNDIndex::Ellipsis))
.collect_vec();
self.index(generator, ctx, &indices)
} else {
*self
}
}
}

View File

@ -0,0 +1,153 @@
use inkwell::values::{BasicValueEnum, IntValue, PointerValue};
use super::ProxyValue;
use crate::codegen::{types::RangeType, CodeGenContext};
/// Proxy type for accessing a `range` value in LLVM.
#[derive(Copy, Clone)]
pub struct RangeValue<'ctx> {
value: PointerValue<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> RangeValue<'ctx> {
/// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance.
pub fn is_representable(value: PointerValue<'ctx>) -> Result<(), String> {
RangeType::is_representable(value.get_type())
}
/// Creates an [`RangeValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self {
debug_assert!(Self::is_representable(ptr).is_ok());
RangeValue { value: ptr, name }
}
fn ptr_to_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.start.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(0, false)],
var_name.as_str(),
)
.unwrap()
}
}
fn ptr_to_end(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.end.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
var_name.as_str(),
)
.unwrap()
}
}
fn ptr_to_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.step.addr")).unwrap_or_default();
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, false)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the `start` value into this instance.
pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, start: IntValue<'ctx>) {
debug_assert_eq!(start.get_type().get_bit_width(), 32);
let pstart = self.ptr_to_start(ctx);
ctx.builder.build_store(pstart, start).unwrap();
}
/// Returns the `start` value of this `range`.
pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let pstart = self.ptr_to_start(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.name.map(|v| format!("{v}.start")))
.unwrap_or_default();
ctx.builder
.build_load(pstart, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
/// Stores the `end` value into this instance.
pub fn store_end(&self, ctx: &CodeGenContext<'ctx, '_>, end: IntValue<'ctx>) {
debug_assert_eq!(end.get_type().get_bit_width(), 32);
let pend = self.ptr_to_end(ctx);
ctx.builder.build_store(pend, end).unwrap();
}
/// Returns the `end` value of this `range`.
pub fn load_end(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let pend = self.ptr_to_end(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.name.map(|v| format!("{v}.end")))
.unwrap_or_default();
ctx.builder.build_load(pend, var_name.as_str()).map(BasicValueEnum::into_int_value).unwrap()
}
/// Stores the `step` value into this instance.
pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, step: IntValue<'ctx>) {
debug_assert_eq!(step.get_type().get_bit_width(), 32);
let pstep = self.ptr_to_step(ctx);
ctx.builder.build_store(pstep, step).unwrap();
}
/// Returns the `step` value of this `range`.
pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>, name: Option<&str>) -> IntValue<'ctx> {
let pstep = self.ptr_to_step(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.name.map(|v| format!("{v}.step")))
.unwrap_or_default();
ctx.builder
.build_load(pstep, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
}
impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = RangeType<'ctx>;
fn get_type(&self) -> Self::Type {
RangeType::from_type(self.value.get_type())
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<RangeValue<'ctx>> for PointerValue<'ctx> {
fn from(value: RangeValue<'ctx>) -> Self {
value.as_base_value()
}
}

View File

@ -0,0 +1,3 @@
pub use slice::*;
mod slice;

View File

@ -0,0 +1,231 @@
use inkwell::{
types::IntType,
values::{IntValue, PointerValue},
};
use nac3parser::ast::Expr;
use crate::{
codegen::{
types::{structure::StructField, utils::SliceType},
values::ProxyValue,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type,
};
/// An IRRT representation of an (unresolved) slice.
#[derive(Copy, Clone)]
pub struct SliceValue<'ctx> {
value: PointerValue<'ctx>,
int_ty: IntType<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> SliceValue<'ctx> {
/// Checks whether `value` is an instance of `ContiguousNDArray`, returning [Err] if `value` is
/// not an instance.
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
<Self as ProxyValue<'ctx>>::Type::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`SliceValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
int_ty: IntType<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
Self { value: ptr, int_ty, llvm_usize, name }
}
fn start_defined_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields().start_defined
}
pub fn load_start_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.start_defined_field().get(ctx, self.value, self.name)
}
fn start_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields().start
}
pub fn load_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.start_field().get(ctx, self.value, self.name)
}
pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option<IntValue<'ctx>>) {
match value {
Some(start) => {
self.start_defined_field().set(
ctx,
self.value,
ctx.ctx.bool_type().const_all_ones(),
self.name,
);
self.start_field().set(ctx, self.value, start, self.name);
}
None => self.start_defined_field().set(
ctx,
self.value,
ctx.ctx.bool_type().const_zero(),
self.name,
),
}
}
fn stop_defined_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields().stop_defined
}
pub fn load_stop_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.stop_defined_field().get(ctx, self.value, self.name)
}
fn stop_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields().stop
}
pub fn load_stop(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.stop_field().get(ctx, self.value, self.name)
}
pub fn store_stop(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option<IntValue<'ctx>>) {
match value {
Some(stop) => {
self.stop_defined_field().set(
ctx,
self.value,
ctx.ctx.bool_type().const_all_ones(),
self.name,
);
self.stop_field().set(ctx, self.value, stop, self.name);
}
None => self.stop_defined_field().set(
ctx,
self.value,
ctx.ctx.bool_type().const_zero(),
self.name,
),
}
}
fn step_defined_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields().step_defined
}
pub fn load_step_defined(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.step_defined_field().get(ctx, self.value, self.name)
}
fn step_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields().step
}
pub fn load_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
self.step_field().get(ctx, self.value, self.name)
}
pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, value: Option<IntValue<'ctx>>) {
match value {
Some(step) => {
self.step_defined_field().set(
ctx,
self.value,
ctx.ctx.bool_type().const_all_ones(),
self.name,
);
self.step_field().set(ctx, self.value, step, self.name);
}
None => self.step_defined_field().set(
ctx,
self.value,
ctx.ctx.bool_type().const_zero(),
self.name,
),
}
}
}
impl<'ctx> ProxyValue<'ctx> for SliceValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = SliceType<'ctx>;
fn get_type(&self) -> Self::Type {
Self::Type::from_type(self.value.get_type(), self.int_ty, self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<SliceValue<'ctx>> for PointerValue<'ctx> {
fn from(value: SliceValue<'ctx>) -> Self {
value.as_base_value()
}
}
/// A slice represented in compile-time by `start`, `stop` and `step`, all held as LLVM values.
// TODO: Rename this to CTConstSlice
#[derive(Debug, Copy, Clone)]
pub struct RustSlice<'ctx> {
int_ty: IntType<'ctx>,
start: Option<IntValue<'ctx>>,
stop: Option<IntValue<'ctx>>,
step: Option<IntValue<'ctx>>,
}
impl<'ctx> RustSlice<'ctx> {
/// Generate LLVM IR for an [`ExprKind::Slice`] and convert it into a [`RustSlice`].
#[allow(clippy::type_complexity)]
pub fn from_slice_expr<G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lower: &Option<Box<Expr<Option<Type>>>>,
upper: &Option<Box<Expr<Option<Type>>>>,
step: &Option<Box<Expr<Option<Type>>>>,
) -> Result<RustSlice<'ctx>, String> {
let mut value_mapper = |value_expr: &Option<Box<Expr<Option<Type>>>>| -> Result<_, String> {
Ok(match value_expr {
None => None,
Some(value_expr) => {
let value_expr = generator
.gen_expr(ctx, value_expr)?
.map(|value| {
value.to_basic_value_enum(ctx, generator, ctx.primitives.int32)
})
.unwrap()?;
Some(value_expr.into_int_value())
}
})
};
let start = value_mapper(lower)?;
let stop = value_mapper(upper)?;
let step = value_mapper(step)?;
Ok(RustSlice { int_ty: ctx.ctx.i32_type(), start, stop, step })
}
/// Write the contents to an LLVM [`SliceValue`].
pub fn write_to_slice(&self, ctx: &CodeGenContext<'ctx, '_>, dst_slice_ptr: SliceValue<'ctx>) {
assert_eq!(self.int_ty, dst_slice_ptr.int_ty);
dst_slice_ptr.store_start(ctx, self.start);
dst_slice_ptr.store_stop(ctx, self.stop);
dst_slice_ptr.store_step(ctx, self.step);
}
}

View File

@ -1,7 +1,25 @@
#![warn(clippy::all)] #![deny(future_incompatible, let_underscore, nonstandard_style, clippy::all)]
#![allow(dead_code)] #![warn(clippy::pedantic)]
#![allow(
dead_code,
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::enum_glob_use,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::similar_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
// users of nac3core need to use the same version of these dependencies, so expose them as nac3core::*
pub use inkwell;
pub use nac3parser;
pub mod codegen; pub mod codegen;
pub mod symbol_resolver; pub mod symbol_resolver;
pub mod toplevel; pub mod toplevel;
pub mod typecheck; pub mod typecheck;
extern crate self as nac3core;

View File

@ -1,24 +1,24 @@
use std::fmt::Debug; use std::{
use std::sync::Arc; collections::{HashMap, HashSet},
use std::{collections::HashMap, collections::HashSet, fmt::Display}; fmt::{Debug, Display},
use std::rc::Rc; rc::Rc,
sync::Arc,
use crate::typecheck::typedef::TypeEnum;
use crate::{
codegen::CodeGenContext,
toplevel::{DefinitionId, TopLevelDef, type_annotation::TypeAnnotation},
}; };
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip, Itertools};
use parking_lot::RwLock;
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use crate::{ use crate::{
codegen::CodeGenerator, codegen::{CodeGenContext, CodeGenerator},
toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{Type, Unifier}, typedef::{Type, TypeEnum, Unifier, VarMap},
}, },
}; };
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip};
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use parking_lot::RwLock;
#[derive(Clone, PartialEq, Debug)] #[derive(Clone, PartialEq, Debug)]
pub enum SymbolValue { pub enum SymbolValue {
@ -43,7 +43,7 @@ impl SymbolValue {
constant: &Constant, constant: &Constant,
expected_ty: Type, expected_ty: Type,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
unifier: &mut Unifier unifier: &mut Unifier,
) -> Result<Self, String> { ) -> Result<Self, String> {
match constant { match constant {
Constant::None => { Constant::None => {
@ -66,35 +66,30 @@ impl SymbolValue {
} else { } else {
Err(format!("Expected {expected_ty:?}, but got str")) Err(format!("Expected {expected_ty:?}, but got str"))
} }
}, }
Constant::Int(i) => { Constant::Int(i) => {
if unifier.unioned(expected_ty, primitives.int32) { if unifier.unioned(expected_ty, primitives.int32) {
i32::try_from(*i) i32::try_from(*i).map(SymbolValue::I32).map_err(|e| e.to_string())
.map(SymbolValue::I32)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.int64) { } else if unifier.unioned(expected_ty, primitives.int64) {
i64::try_from(*i) i64::try_from(*i).map(SymbolValue::I64).map_err(|e| e.to_string())
.map(SymbolValue::I64)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint32) { } else if unifier.unioned(expected_ty, primitives.uint32) {
u32::try_from(*i) u32::try_from(*i).map(SymbolValue::U32).map_err(|e| e.to_string())
.map(SymbolValue::U32)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint64) { } else if unifier.unioned(expected_ty, primitives.uint64) {
u64::try_from(*i) u64::try_from(*i).map(SymbolValue::U64).map_err(|e| e.to_string())
.map(SymbolValue::U64)
.map_err(|e| e.to_string())
} else { } else {
Err(format!("Expected {}, but got int", unifier.stringify(expected_ty))) Err(format!("Expected {}, but got int", unifier.stringify(expected_ty)))
} }
} }
Constant::Tuple(t) => { Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty); let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty.as_ref() else { let TypeEnum::TTuple { ty, is_vararg_ctx } = expected_ty.as_ref() else {
return Err(format!("Expected {:?}, but got Tuple", expected_ty.get_type_name())) return Err(format!(
"Expected {:?}, but got Tuple",
expected_ty.get_type_name()
));
}; };
assert_eq!(ty.len(), t.len()); assert!(*is_vararg_ctx || ty.len() == t.len());
let elems = t let elems = t
.iter() .iter()
@ -109,7 +104,45 @@ impl SymbolValue {
} else { } else {
Err(format!("Expected {expected_ty:?}, but got float")) Err(format!("Expected {expected_ty:?}, but got float"))
} }
}, }
_ => Err(format!("Unsupported value type {constant:?}")),
}
}
/// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value.
///
/// * `constant` - The constant to create the value from.
pub fn from_constant_inferred(constant: &Constant) -> Result<Self, String> {
match constant {
Constant::None => Ok(SymbolValue::OptionNone),
Constant::Bool(b) => Ok(SymbolValue::Bool(*b)),
Constant::Str(s) => Ok(SymbolValue::Str(s.to_string())),
Constant::Int(i) => {
let i = *i;
if i >= 0 {
i32::try_from(i)
.map(SymbolValue::I32)
.or_else(|_| i64::try_from(i).map(SymbolValue::I64))
.map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
} else {
u32::try_from(i)
.map(SymbolValue::U32)
.or_else(|_| u64::try_from(i).map(SymbolValue::U64))
.map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
}
}
Constant::Tuple(t) => {
let elems = t
.iter()
.map(Self::from_constant_inferred)
.collect::<Result<Vec<SymbolValue>, _>>()?;
Ok(SymbolValue::Tuple(elems))
}
Constant::Float(f) => Ok(SymbolValue::Double(*f)),
_ => Err(format!("Unsupported value type {constant:?}")), _ => Err(format!("Unsupported value type {constant:?}")),
} }
} }
@ -125,28 +158,27 @@ impl SymbolValue {
SymbolValue::Double(_) => primitives.float, SymbolValue::Double(_) => primitives.float,
SymbolValue::Bool(_) => primitives.bool, SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => { SymbolValue::Tuple(vs) => {
let vs_tys = vs let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
.iter() unifier.add_ty(TypeEnum::TTuple { ty: vs_tys, is_vararg_ctx: false })
.map(|v| v.get_type(primitives, unifier))
.collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple {
ty: vs_tys,
})
} }
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option, SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
} }
} }
/// Returns the [`TypeAnnotation`] representing the data type of this value. /// Returns the [`TypeAnnotation`] representing the data type of this value.
pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation { pub fn get_type_annotation(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> TypeAnnotation {
match self { match self {
SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool), SymbolValue::Bool(..)
SymbolValue::Double(..) => TypeAnnotation::Primitive(primitives.float), | SymbolValue::Double(..)
SymbolValue::I32(..) => TypeAnnotation::Primitive(primitives.int32), | SymbolValue::I32(..)
SymbolValue::I64(..) => TypeAnnotation::Primitive(primitives.int64), | SymbolValue::I64(..)
SymbolValue::U32(..) => TypeAnnotation::Primitive(primitives.uint32), | SymbolValue::U32(..)
SymbolValue::U64(..) => TypeAnnotation::Primitive(primitives.uint64), | SymbolValue::U64(..)
SymbolValue::Str(..) => TypeAnnotation::Primitive(primitives.str), | SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)),
SymbolValue::Tuple(vs) => { SymbolValue::Tuple(vs) => {
let vs_tys = vs let vs_tys = vs
.iter() .iter()
@ -155,13 +187,13 @@ impl SymbolValue {
TypeAnnotation::Tuple(vs_tys) TypeAnnotation::Tuple(vs_tys)
} }
SymbolValue::OptionNone => TypeAnnotation::CustomClass { SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitives.option.get_obj_id(unifier), id: primitives.option.obj_id(unifier).unwrap(),
params: Vec::default(), params: Vec::default(),
}, },
SymbolValue::OptionSome(v) => { SymbolValue::OptionSome(v) => {
let ty = v.get_type_annotation(primitives, unifier); let ty = v.get_type_annotation(primitives, unifier);
TypeAnnotation::CustomClass { TypeAnnotation::CustomClass {
id: primitives.option.get_obj_id(unifier), id: primitives.option.obj_id(unifier).unwrap(),
params: vec![ty], params: vec![ty],
} }
} }
@ -169,7 +201,11 @@ impl SymbolValue {
} }
/// Returns the [`TypeEnum`] representing the data type of this value. /// Returns the [`TypeEnum`] representing the data type of this value.
pub fn get_type_enum(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Rc<TypeEnum> { pub fn get_type_enum(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> Rc<TypeEnum> {
let ty = self.get_type(primitives, unifier); let ty = self.get_type(primitives, unifier);
unifier.get_ty(ty) unifier.get_ty(ty)
} }
@ -200,6 +236,38 @@ impl Display for SymbolValue {
} }
} }
impl TryFrom<SymbolValue> for u64 {
type Error = ();
/// Tries to convert a [`SymbolValue`] into a [`u64`], returning [`Err`] if the value is not
/// numeric or if the value cannot be converted into a `u64` without overflow.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::U32(v) => Ok(u64::from(v)),
SymbolValue::U64(v) => Ok(v),
_ => Err(()),
}
}
}
impl TryFrom<SymbolValue> for i128 {
type Error = ();
/// Tries to convert a [`SymbolValue`] into a [`i128`], returning [`Err`] if the value is not
/// numeric.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => Ok(i128::from(v)),
SymbolValue::I64(v) => Ok(i128::from(v)),
SymbolValue::U32(v) => Ok(i128::from(v)),
SymbolValue::U64(v) => Ok(i128::from(v)),
_ => Err(()),
}
}
}
pub trait StaticValue { pub trait StaticValue {
/// Returns a unique identifier for this value. /// Returns a unique identifier for this value.
fn get_unique_identifier(&self) -> u64; fn get_unique_identifier(&self) -> u64;
@ -232,10 +300,10 @@ pub trait StaticValue {
#[derive(Clone)] #[derive(Clone)]
pub enum ValueEnum<'ctx> { pub enum ValueEnum<'ctx> {
/// [ValueEnum] representing a static value. /// [`ValueEnum`] representing a static value.
Static(Arc<dyn StaticValue + Send + Sync>), Static(Arc<dyn StaticValue + Send + Sync>),
/// [ValueEnum] representing a dynamic value. /// [`ValueEnum`] representing a dynamic value.
Dynamic(BasicValueEnum<'ctx>), Dynamic(BasicValueEnum<'ctx>),
} }
@ -270,7 +338,6 @@ impl<'ctx> From<StructValue<'ctx>> for ValueEnum<'ctx> {
} }
impl<'ctx> ValueEnum<'ctx> { impl<'ctx> ValueEnum<'ctx> {
/// Converts this [`ValueEnum`] to a [`BasicValueEnum`]. /// Converts this [`ValueEnum`] to a [`BasicValueEnum`].
pub fn to_basic_value_enum<'a>( pub fn to_basic_value_enum<'a>(
self, self,
@ -302,6 +369,7 @@ pub trait SymbolResolver {
&self, &self,
str: StrRef, str: StrRef,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>>; ) -> Option<ValueEnum<'ctx>>;
fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>; fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>;
@ -312,7 +380,7 @@ pub trait SymbolResolver {
&self, &self,
_unifier: &mut Unifier, _unifier: &mut Unifier,
_top_level_defs: &[Arc<RwLock<TopLevelDef>>], _top_level_defs: &[Arc<RwLock<TopLevelDef>>],
_primitives: &PrimitiveStore _primitives: &PrimitiveStore,
) -> Result<(), String> { ) -> Result<(), String> {
Ok(()) Ok(())
} }
@ -325,12 +393,12 @@ thread_local! {
"float".into(), "float".into(),
"bool".into(), "bool".into(),
"virtual".into(), "virtual".into(),
"list".into(),
"tuple".into(), "tuple".into(),
"str".into(), "str".into(),
"Exception".into(), "Exception".into(),
"uint32".into(), "uint32".into(),
"uint64".into(), "uint64".into(),
"Literal".into(),
]; ];
} }
@ -349,12 +417,12 @@ pub fn parse_type_annotation<T>(
let float_id = ids[2]; let float_id = ids[2];
let bool_id = ids[3]; let bool_id = ids[3];
let virtual_id = ids[4]; let virtual_id = ids[4];
let list_id = ids[5]; let tuple_id = ids[5];
let tuple_id = ids[6]; let str_id = ids[6];
let str_id = ids[7]; let exn_id = ids[7];
let exn_id = ids[8]; let uint32_id = ids[8];
let uint32_id = ids[9]; let uint64_id = ids[9];
let uint64_id = ids[10]; let literal_id = ids[10];
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| { let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
if *id == int32_id { if *id == int32_id {
@ -379,40 +447,29 @@ pub fn parse_type_annotation<T>(
let def = top_level_defs[obj_id.0].read(); let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if !type_vars.is_empty() { if !type_vars.is_empty() {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "Unexpected number of type parameters: expected {} but got 0",
"Unexpected number of type parameters: expected {} but got 0", type_vars.len()
type_vars.len() )]));
),
]))
} }
let fields = chain( let fields = chain(
fields.iter().map(|(k, v, m)| (*k, (*v, *m))), fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))), methods.iter().map(|(k, v, _)| (*k, (*v, false))),
) )
.collect(); .collect();
Ok(unifier.add_ty(TypeEnum::TObj { Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: VarMap::default() }))
obj_id,
fields,
params: HashMap::default(),
}))
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!("Cannot use function name as type at {loc}")]))
format!("Cannot use function name as type at {loc}"),
]))
} }
} else { } else {
let ty = resolver let ty =
.get_symbol_type(unifier, top_level_defs, primitives, *id) resolver.get_symbol_type(unifier, top_level_defs, primitives, *id).map_err(
.map_err(|e| HashSet::from([ |e| HashSet::from([format!("Unknown type annotation at {loc}: {e}")]),
format!("Unknown type annotation at {loc}: {e}"), )?;
]))?;
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
Ok(ty) Ok(ty)
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!("Unknown type annotation {id} at {loc}")]))
format!("Unknown type annotation {id} at {loc}"),
]))
} }
} }
} }
@ -422,9 +479,6 @@ pub fn parse_type_annotation<T>(
if *id == virtual_id { if *id == virtual_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?; let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} else if *id == list_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
} else if *id == tuple_id { } else if *id == tuple_id {
if let Tuple { elts, .. } = &slice.node { if let Tuple { elts, .. } = &slice.node {
let ty = elts let ty = elts
@ -433,12 +487,33 @@ pub fn parse_type_annotation<T>(
parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt) parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty })) Ok(unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }))
} else { } else {
Err(HashSet::from([ Err(HashSet::from(["Expected multiple elements for tuple".into()]))
"Expected multiple elements for tuple".into()
]))
} }
} else if *id == literal_id {
let mut parse_literal = |elt: &Expr<T>| {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)?;
let ty_enum = &*unifier.get_ty_immutable(ty);
match ty_enum {
TypeEnum::TLiteral { values, .. } => Ok(values.clone()),
_ => Err(HashSet::from([format!(
"Expected literal in type argument for Literal at {}",
elt.location
)])),
}
};
let values = if let Tuple { elts, .. } = &slice.node {
elts.iter().map(&mut parse_literal).collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_literal(slice)?]
}
.into_iter()
.flatten()
.collect_vec();
Ok(unifier.get_fresh_literal(values, Some(slice.location)))
} else { } else {
let types = if let Tuple { elts, .. } = &slice.node { let types = if let Tuple { elts, .. } = &slice.node {
elts.iter() elts.iter()
@ -454,15 +529,13 @@ pub fn parse_type_annotation<T>(
let def = top_level_defs[obj_id.0].read(); let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() { if types.len() != type_vars.len() {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "Unexpected number of type parameters: expected {} but got {}",
"Unexpected number of type parameters: expected {} but got {}", type_vars.len(),
type_vars.len(), types.len()
types.len() )]));
),
]))
} }
let mut subst = HashMap::new(); let mut subst = VarMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) { for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id *id
@ -484,9 +557,7 @@ pub fn parse_type_annotation<T>(
})); }));
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst })) Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst }))
} else { } else {
Err(HashSet::from([ Err(HashSet::from(["Cannot use function name as type".into()]))
"Cannot use function name as type".into(),
]))
} }
} }
}; };
@ -497,14 +568,13 @@ pub fn parse_type_annotation<T>(
if let Name { id, .. } = &value.node { if let Name { id, .. } = &value.node {
subscript_name_handle(id, slice, unifier) subscript_name_handle(id, slice, unifier)
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!("unsupported type expression at {}", expr.location)]))
format!("unsupported type expression at {}", expr.location),
]))
} }
} }
_ => Err(HashSet::from([ Constant { value, .. } => SymbolValue::from_constant_inferred(value)
format!("unsupported type expression at {}", expr.location), .map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location)))
])), .map_err(|err| HashSet::from([err])),
_ => Err(HashSet::from([format!("unsupported type expression at {}", expr.location)])),
} }
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -6,53 +6,64 @@ use std::{
sync::Arc, sync::Arc,
}; };
use super::codegen::CodeGenContext;
use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier};
use crate::{
codegen::CodeGenerator,
symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{type_inferencer::CodeLocation, typedef::CallId},
};
use inkwell::values::BasicValueEnum; use inkwell::values::BasicValueEnum;
use itertools::{izip, Itertools}; use itertools::Itertools;
use nac3parser::ast::{self, Location, Stmt, StrRef};
use parking_lot::RwLock; use parking_lot::RwLock;
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)] use nac3parser::ast::{self, Expr, Location, Stmt, StrRef};
pub struct DefinitionId(pub usize);
use crate::{
codegen::{CodeGenContext, CodeGenerator},
symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{
CallId, FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, TypeVarId, Unifier,
VarMap,
},
},
};
use composer::*;
use type_annotation::*;
pub mod builtins; pub mod builtins;
pub mod composer; pub mod composer;
pub mod helper; pub mod helper;
pub mod type_annotation; pub mod numpy;
use composer::*;
use type_annotation::*;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
pub mod type_annotation;
type GenCallCallback = Box< #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
dyn for<'ctx, 'a> Fn( pub struct DefinitionId(pub usize);
&mut CodeGenContext<'ctx, 'a>,
Option<(Type, ValueEnum<'ctx>)>, type GenCallCallback = dyn for<'ctx, 'a> Fn(
(&FunSignature, DefinitionId), &mut CodeGenContext<'ctx, 'a>,
Vec<(Option<StrRef>, ValueEnum<'ctx>)>, Option<(Type, ValueEnum<'ctx>)>,
&mut dyn CodeGenerator, (&FunSignature, DefinitionId),
) -> Result<Option<BasicValueEnum<'ctx>>, String> Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
+ Send &mut dyn CodeGenerator,
+ Sync, ) -> Result<Option<BasicValueEnum<'ctx>>, String>
>; + Send
+ Sync;
pub struct GenCall { pub struct GenCall {
fp: GenCallCallback, fp: Box<GenCallCallback>,
} }
impl GenCall { impl GenCall {
#[must_use] #[must_use]
pub fn new(fp: GenCallCallback) -> GenCall { pub fn new(fp: Box<GenCallCallback>) -> GenCall {
GenCall { fp } GenCall { fp }
} }
/// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given
/// `reason`.
#[must_use]
pub fn create_dummy(reason: String) -> GenCall {
Self::new(Box::new(move |_, _, _, _, _| unreachable!("{reason}")))
}
pub fn run<'ctx>( pub fn run<'ctx>(
&self, &self,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -75,7 +86,7 @@ impl Debug for GenCall {
pub struct FunInstance { pub struct FunInstance {
pub body: Arc<Vec<Stmt<Option<Type>>>>, pub body: Arc<Vec<Stmt<Option<Type>>>>,
pub calls: Arc<HashMap<CodeLocation, CallId>>, pub calls: Arc<HashMap<CodeLocation, CallId>>,
pub subst: HashMap<u32, Type>, pub subst: VarMap,
pub unifier_id: usize, pub unifier_id: usize,
} }
@ -84,7 +95,7 @@ pub enum TopLevelDef {
Class { Class {
/// Name for error messages and symbols. /// Name for error messages and symbols.
name: StrRef, name: StrRef,
/// Object ID used for [TypeEnum]. /// Object ID used for [`TypeEnum`].
object_id: DefinitionId, object_id: DefinitionId,
/// type variables bounded to the class. /// type variables bounded to the class.
type_vars: Vec<Type>, type_vars: Vec<Type>,
@ -92,6 +103,10 @@ pub enum TopLevelDef {
/// ///
/// Name and type is mutable. /// Name and type is mutable.
fields: Vec<(StrRef, Type, bool)>, fields: Vec<(StrRef, Type, bool)>,
/// Class Attributes.
///
/// Name, type, value.
attributes: Vec<(StrRef, Type, ast::Constant)>,
/// Class methods, pointing to the corresponding function definition. /// Class methods, pointing to the corresponding function definition.
methods: Vec<(StrRef, Type, DefinitionId)>, methods: Vec<(StrRef, Type, DefinitionId)>,
/// Ancestor classes, including itself. /// Ancestor classes, including itself.
@ -111,18 +126,18 @@ pub enum TopLevelDef {
/// Function signature. /// Function signature.
signature: Type, signature: Type,
/// Instantiated type variable IDs. /// Instantiated type variable IDs.
var_id: Vec<u32>, var_id: Vec<TypeVarId>,
/// Function instance to symbol mapping /// Function instance to symbol mapping
/// ///
/// * Key: String representation of type variable values, sorted by variable ID in ascending /// * Key: String representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. /// order, including type variables associated with the class.
/// * Value: Function symbol name. /// * Value: Function symbol name.
instance_to_symbol: HashMap<String, String>, instance_to_symbol: HashMap<String, String>,
/// Function instances to annotated AST mapping /// Function instances to annotated AST mapping
/// ///
/// * Key: String representation of type variable values, sorted by variable ID in ascending /// * Key: String representation of type variable values, sorted by variable ID in ascending
/// order, including type variables associated with the class. Excluding rigid type /// order, including type variables associated with the class. Excluding rigid type
/// variables. /// variables.
/// ///
/// Rigid type variables that would be substituted when the function is instantiated. /// Rigid type variables that would be substituted when the function is instantiated.
instance_to_stmt: HashMap<String, FunInstance>, instance_to_stmt: HashMap<String, FunInstance>,
@ -133,6 +148,25 @@ pub enum TopLevelDef {
/// Definition location. /// Definition location.
loc: Option<Location>, loc: Option<Location>,
}, },
Variable {
/// Qualified name of the global variable, should be unique globally.
name: String,
/// Simple name, the same as in method/function definition.
simple_name: StrRef,
/// Type of the global variable.
ty: Type,
/// The declared type of the global variable, or [`None`] if no type annotation is provided.
ty_decl: Option<Expr>,
/// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Definition location.
loc: Option<Location>,
},
} }
pub struct TopLevelContext { pub struct TopLevelContext {

Some files were not shown because too many files have changed in this diff Show More