forked from M-Labs/nac3
1
0
Fork 0

Compare commits

..

239 Commits

Author SHA1 Message Date
lyken 87d2a4ed59 WIP 2024-07-10 17:27:10 +08:00
lyken 9aae290727 core: irrt general numpy broadcasting 2024-07-10 17:05:01 +08:00
lyken d18c769cdc core: irrt general numpy slicing 2024-07-10 14:05:08 +08:00
lyken f41f06aec7 core: more irrt 2024-07-10 11:56:31 +08:00
lyken 1303265785 core: build.rs rewrite regex to capture `= type` 2024-07-10 10:17:45 +08:00
lyken e9cf6ce1e5 core: move irrt c++ sources to /nac3core/irrt 2024-07-10 10:17:45 +08:00
lyken bc91ab9b13 core: IRRT -Werror=return-type 2024-07-10 10:17:43 +08:00
lyken 1e06a3d199 core: add irrt_test 2024-07-10 10:11:07 +08:00
lyken 87511ac749 core: comment out numpy 2024-07-10 10:05:07 +08:00
Sebastien Bourdeauducq d658d9b00e update dependencies, Python 3.12 on Linux 2024-07-09 23:56:12 +08:00
abdul124 eeb474f9e6 core: reduce code duplication in codegen/extern_fns (#453)
Used macros to reduce code duplication in `codegen/extern_fns`

Reviewed-on: M-Labs/nac3#453
Co-authored-by: abdul124 <ar@m-labs.hk>
Co-committed-by: abdul124 <ar@m-labs.hk>
2024-07-09 16:31:08 +08:00
abdul124 88b72af2d1 core/llvm_intrinsic: improve macro name and comments 2024-07-09 16:30:32 +08:00
abdul124 b73f6c4d68 core: reduce code duplication in codegen/llvm_intrinsic 2024-07-09 16:30:32 +08:00
David Mak f47cdec650 standalone: Fix output format of output_range 2024-07-09 13:55:48 +08:00
David Mak d656880e44 standalone: Fix missing implementation for output_range 2024-07-09 13:53:50 +08:00
David Mak a91602915a core: Fix missing fields in range type 2024-07-09 13:53:50 +08:00
David Mak 1c56005a01 core: Reformat and modernize irrt.cpp
- Use anon namespace instead of static
- Use using declaration instead of typedef
- Align pointers to the type instead of the identifier
2024-07-09 13:53:50 +08:00
David Mak bc40a32524 core: Add report_type_error to enable more code reuse 2024-07-09 13:44:47 +08:00
David Mak c820daf5f8 core: Apply cargo format 2024-07-09 13:32:10 +08:00
David Mak 25d2de67f7 standalone: Add output_range and tests 2024-07-09 04:44:40 +08:00
David Mak 2cfb7a7e10 core: Refactor range function into constructor 2024-07-09 04:44:40 +08:00
David Mak 9238a5e86e standalone: Rename output_str to output_strln and add output_str
output_str is for outputting strings without newline, and the newly
introduced output_strln now has the old behavior of ending with a
newline.
2024-07-09 04:44:40 +08:00
lyken 76defac462 meta: use clang -x c++ instead of clang++ 2024-07-07 20:03:34 +08:00
lyken 650f354b74 core: use C++ for irrt source 2024-07-07 14:36:10 +08:00
abdul124 f062ef5f59 core/llvm_intrinsic: replace roundeven with rint 2024-07-07 14:24:18 +08:00
lyken f52086b706 core: improve binop and cmpop error messages 2024-07-05 16:27:24 +08:00
lyken 0a732691c9 core: refactor typecheck/magic_methods.rs operators & add op symbol name 2024-07-05 16:27:20 +08:00
lyken cbff356d50 core: workaround inkwell on `llvm.stackrestore` 2024-07-05 13:56:12 +08:00
lyken 24ac3820b2 core: check int32 obj_id directly in fold_numpy_function_call_shape_argument 2024-07-05 10:36:47 +08:00
David Mak ba32fab374 standalone: Add demos for list arithmetic operators 2024-07-04 16:01:15 +08:00
David Mak c4052b6342 core: Implement multi-operand __eq__ and __ne__ for lists 2024-07-04 16:01:15 +08:00
David Mak 66c205275f core: Implement list::__add__ 2024-07-04 16:01:11 +08:00
David Mak c85e412206 core: Implement list::__mul__ 2024-07-04 15:53:50 +08:00
David Mak 075536d7bd core: Add BreakContinueHooks for gen_for_callback 2024-07-04 15:32:18 +08:00
David Mak 13beeaa2bf core: Implement handling for zero-length lists 2024-07-04 15:32:18 +08:00
David Mak 2194dbddd5 core/type_annotation: Refactor List type to TObj
In preparation for operators on lists.
2024-07-04 15:32:18 +08:00
David Mak 94a1d547d6 meta: Update dependencies 2024-07-04 15:32:18 +08:00
lyken d6565feed3 core: ndarray_from_ndlist_impl cast size_of to usize 2024-07-04 12:24:52 +08:00
abdul124 83154ef8e1 core/llvm_intrinsics: remove llvm.roundeven call from call_float_roundeven 2024-07-03 14:17:47 +08:00
lyken 0744b938b8 core: fix __nac3_ndarray_calc_size crash due to incorrect typing 2024-07-03 13:03:14 +08:00
lyken 56fa2b6803 core: fix crash on iterating over non-iterables
a
2024-06-28 15:45:53 +08:00
lyken d06c13f936 core: fix crash on invalid subscripting 2024-06-27 16:58:48 +08:00
lyken 9808923258 core: improve comments in type_inferencer/mod.rs 2024-06-27 14:46:48 +08:00
lyken 5b11a1dbdd core: support tuple and int32 input for np_empty, np_ones, and more 2024-06-27 14:30:17 +08:00
lyken b21df53e0d core: fix comment typo in unify_call() 2024-06-27 14:06:39 +08:00
lyken 0ec967a468 core: improve function call errors 2024-06-27 14:06:39 +08:00
lyken ca8459dc7b standalone: prettify TopLevelComposer error reporting 2024-06-27 10:15:14 +08:00
abdul124 b0b804051a nac3artiq: allow class attribute access without init function 2024-06-25 16:06:33 +08:00
abdul124 134af79fd6 core: add support for class attributes 2024-06-25 16:06:33 +08:00
abdul124 7fe2c3496c core: add attribute field to class definition 2024-06-25 16:06:33 +08:00
lyken 144a3fc426 core: more derive Debug in typedef 2024-06-25 15:02:50 +08:00
lyken 74096eb9f6 core: name codegen worker threads 2024-06-25 12:36:37 +08:00
lyken 06e9d90d57 apply clippy changes 2024-06-21 14:14:01 +08:00
lyken d89146aa02 core: use no_run on builtin_fns docs 2024-06-20 13:53:25 +08:00
David Mak 5bade81ddb standalone: Add test for multidim array index with one index 2024-06-20 12:50:30 +08:00
David Mak 0452e6de78 core: Fix codegen for tuple-index into ndarray 2024-06-20 12:50:30 +08:00
David Mak 635c944c90 core: Fix type inference for tuple-index into ndarray
Fixes #420.
2024-06-20 12:50:30 +08:00
lyken e36af3b0a3 core: reduce code duplication in codegen/builtin_fns (#422)
Used macros to generate some unary math functions.

Reviewed-on: M-Labs/nac3#422
Reviewed-by: David Mak <chmakac@connect.ust.hk>
Co-authored-by: lyken <lyken@m-labs.hk>
Co-committed-by: lyken <lyken@m-labs.hk>
2024-06-20 12:48:44 +08:00
Sebastien Bourdeauducq 5b1aa812ed update dependencies 2024-06-20 10:43:55 +08:00
David Mak d3cd2a8d99 artiq: Add support for generating RPC tag for ndarray 2024-06-19 18:56:16 +08:00
David Mak 202a63274d artiq: Implement pyty-to-ty conversion 2024-06-19 18:56:15 +08:00
David Mak 76dd5191f5 artiq: Implement Python-to-LLVM conversion of ndarray 2024-06-19 18:56:15 +08:00
David Mak 8d9df0a615 artiq: Fix ndarray class ID
We want the class ID of the ndarray class, not its corresponding typing
class.
2024-06-19 18:56:15 +08:00
David Mak 07adfb2a18 standalone: Add *.ll to Gitignore list 2024-06-19 18:56:15 +08:00
Sébastien Bourdeauducq f00e458f60 add test for class without __init__ 2024-06-19 18:16:54 +08:00
David Mak 1bc95a7ba6 Add handling for np.bool_ and np.str_ 2024-06-19 15:10:47 +08:00
lyken e85f4f9bd2 core: refactor top_level::builtins::get_builtins() 2024-06-18 11:06:25 +08:00
abdul124 ce3e9bf4fe nac3artiq: add support string attributes in classes 2024-06-17 16:53:51 +08:00
David Mak 82091b1be8 meta: Apply clippy changes 2024-06-17 14:10:31 +08:00
David Mak 32919949e2 Run clippy --tests on pre-commit hook 2024-06-17 12:51:25 +08:00
lyken 2abe75d1f4 core: remove code dup with `make_exception_fields` 2024-06-17 12:01:48 +08:00
lyken 676412fe6d apply cargo fmt 2024-06-14 09:46:42 +08:00
lyken 8b9df7252f core: cleanup with Unifier::generate_var_id 2024-06-14 09:42:04 +08:00
lyken 6979843431 core: fix typo in into_var_map 2024-06-13 16:59:10 +08:00
lyken fed1361c6a core: rename to_var_map to into_var_map 2024-06-13 16:59:10 +08:00
lyken aa94e0c8a4 core: remove pub & add From<TypeVarId> for u32 2024-06-13 16:59:10 +08:00
lyken f523e26227 core: fix typo in fmt::Display of TypeVarId 2024-06-13 16:59:10 +08:00
lyken f026b48e2a core: refactor to use `TypeVarId` and `TypeVar` 2024-06-13 16:59:10 +08:00
lyken dc874f2994 core: use `PrimDef` simple names in make_primitives() 2024-06-13 16:58:32 +08:00
lyken 95de0800b4 core/demo: fix typo in .gitignore 2024-06-13 16:05:33 +08:00
lyken 3d71c6a850 core/demo: gitignore to ignore *.bc & *.o 2024-06-13 16:00:23 +08:00
David Mak be55e2ac80 meta: Update README to include info regarding pre-commit hooks 2024-06-12 16:10:57 +08:00
David Mak 79c8b759ad meta: Add pre-commit configuration 2024-06-12 16:10:57 +08:00
David Mak 4798c53a21 flake: Add pre-commit to dev environment 2024-06-12 16:10:57 +08:00
David Mak 23974feae7 meta: Restrict number of allowed lints 2024-06-12 16:10:57 +08:00
David Mak 40a3bded36 meta: Set clippy lints in {main,lib}.rs
So that this does not have to be manually passed to the `cargo clippy`
command-line every single time. Also allows incrementally addressing
these lints by removing and fixing them one-by-one.
2024-06-12 16:10:57 +08:00
lyken c4420e6ab9 core: refactor `get_builtins()` 2024-06-12 15:09:20 +08:00
lyken fd36f78005 core: refactor `PrimitiveDefinitionId` into enum `PrimDef` 2024-06-12 15:01:01 +08:00
lyken 8168692cc3 apply cargo fmt 2024-06-12 14:45:03 +08:00
David Mak 53d44b9595 standalone: Add np_array tests 2024-06-11 16:44:36 +08:00
David Mak 6153f94b05 core/numpy: Implement codegen for np_array 2024-06-11 16:42:11 +08:00
David Mak 4730b595f3 core/builtins: Add np_array function 2024-06-11 16:42:08 +08:00
David Mak c2fdb12397 core/type_inferencer: Add special rule for np_array 2024-06-11 16:40:35 +08:00
David Mak 82bf14785b core: Add multidimensional array helpers 2024-06-11 15:30:06 +08:00
David Mak 2d4329e23c core/stmt: Use BB of last statement in if-else in phi 2024-06-11 15:30:06 +08:00
David Mak 679656f9e1 core/classes: Fix incorrect field locations for lists 2024-06-11 15:30:06 +08:00
David Mak 210d9e2334 core: Add more creator functions for ProxyType 2024-06-11 15:26:37 +08:00
David Mak 181ac3ec1a core/classes: Fix incorrect pointers of range.{stop,step} 2024-06-11 15:13:31 +08:00
David Mak 3acdfb304d meta: Apply clippy suggestions 2024-06-11 14:58:32 +08:00
David Mak 6e24da9cc5 meta: Update dependencies 2024-06-11 14:58:32 +08:00
David Mak f0ab1b858a core: Refactor class abstractions
- Introduce new Type abstractions
- Rearrange some functions
2024-06-06 13:45:51 +08:00
lyken 08129cc635 nac3core: add TopLevelComposer::new builtin check's assertion msg 2024-06-05 15:30:02 +08:00
David Mak ad4832dcf4 core: Refactor to get LLVM intrinsics via Intrinsics::find 2024-06-05 15:29:40 +08:00
lyken 520bbb246b flake: add llvmPackages_14.llvm to devShells linux default (#405)
Co-authored-by: lyken <lyken@m-labs.hk>
Co-committed-by: lyken <lyken@m-labs.hk>
2024-06-05 11:11:56 +08:00
lyken b857f1e403 nac3core: fix typo in gen_for's comment 2024-06-04 17:15:41 +08:00
Sebastien Bourdeauducq fa8af37e84 flake: update nixpkgs 2024-06-03 22:22:04 +08:00
David Mak 23b2fee4e7 standalone: Add test case for ndarray slicing 2024-06-03 16:40:05 +08:00
David Mak ed79d5bb9e core/expr: Add support for multi-dim slicing of NDArrays 2024-06-03 16:40:05 +08:00
David Mak c35ad06949 core/expr: Add support for 1D slicing of NDArrays 2024-06-03 16:40:05 +08:00
David Mak 135ef557f9 core/numpy: Implement ndarray_sliced_{copy,copyto_impl}
Performing copying with optional support for slicing. Also made
copy_impl delegate to sliced_copy, as sliced_copy now performs a
superset of operations that copy_impl can already do.
2024-06-03 16:40:05 +08:00
David Mak a176c3eb70 core/irrt: Change handle_slice_indices to instead take length of object
So that all other array-like datatypes (e.g. ndarray) can also take
advantage of it.
2024-06-03 16:40:05 +08:00
David Mak 2cf79510c2 core/numpy: Add more helper functions 2024-06-03 16:40:05 +08:00
David Mak b6ff75dcaf core/irrt: Add support for calculating partial size of NDArray 2024-06-03 16:40:05 +08:00
David Mak 588c15f80d core/stmt: Add gen_for_range_callback
For generating for loops over range objects or array slices.
2024-06-03 16:40:05 +08:00
David Mak 82cc693b11 meta: Update dependencies 2024-06-03 16:40:02 +08:00
David Mak 520e1adc56 core/builtins: Add np_minimum/np_maximum 2024-05-09 15:01:20 +08:00
David Mak 73e81259f3 core/builtins: Add np_min/np_max 2024-05-09 15:01:20 +08:00
David Mak 7627acea41 core/type_inferencer: Fix error message 2024-05-09 15:01:20 +08:00
David Mak a777099ea8 core/type_inferencer: Fix missing lowering for some builtin TVars 2024-05-09 15:01:20 +08:00
David Mak 876e6ea7b8 meta: Update dependencies 2024-05-08 17:27:38 +08:00
David Mak 30c6cffbad core/builtins: Refactored numpy builtins to accept scalar and ndarrays 2024-05-06 15:38:29 +08:00
David Mak 51671800b6 core/builtins: Extract codegen portion into functions
We will need to reuse them when implementing elementwise function
application for ndarrays.
2024-05-06 13:21:54 +08:00
David Mak 7195476edb core/builtins: Add llvm_intrinsics prefix 2024-05-06 13:21:54 +08:00
David Mak eecba0b71d core: Add GenCall::create_dummy
A simple abstraction for GenCalls that are already handled elsewhere.
2024-05-06 13:21:54 +08:00
David Mak 7b4253ccd8 core/numpy: Add missing lifetime parameters 2024-05-06 13:21:54 +08:00
David Mak f58c3a11f8 core/builtins: Rework handling of PrimitiveStore-Unifier tuples 2024-05-06 13:21:54 +08:00
David Mak d0766a116f core: Remove Box from GenCallCallback type alias
So that references to the function type can be taken.
2024-05-06 13:21:54 +08:00
David Mak 64a3751fc2 core: Remove custom function type definitions for ndarray operators 2024-05-06 13:21:54 +08:00
David Mak 9566047241 standalone: Fix cbrt never tested 2024-05-06 13:21:54 +08:00
David Mak 062e318dd5 core/magic_methods: Fix clippy warnings 2024-05-06 13:21:54 +08:00
David Mak c4dc36ae99 standalone: Add explicit `--` for delimiting run args vs NAC3 args 2024-05-06 13:21:54 +08:00
David Mak baac348ee6 meta: Update dependencies 2024-05-06 13:21:37 +08:00
David Mak 847615fc2f core: Implement numpy.matmul for 2D-2D ndarrays 2024-04-23 10:27:37 +08:00
David Mak 5dfcc63978 core/classes: Take reference of indexes 2024-04-16 17:20:24 +08:00
David Mak 025b3cd02f core/stmt: Remove gen_if_chained*
Turns out it is really difficult to get lifetimes and closures right, so
let's just provide the most rudimentary if-else codegen and we can nest
them if necessary.
2024-04-16 17:16:50 +08:00
David Mak e0f440040c core/expr: Implement negative indices for ndarray 2024-04-15 12:49:42 +08:00
David Mak f0715e2b6d core/stmt: Add gen_if* functions
For generating if-constructs in IR.
2024-04-15 12:20:34 +08:00
David Mak e7fca67786 core/stmt: Do not generate jumps if bb is already terminated
Future-proofs gen_*_callback functions in case other codegen functions
will delegate to it in the future.
2024-04-15 12:20:34 +08:00
David Mak 52c731c312 core: Implement Not/UAdd/USub for booleans
Not sure if this is deliberate or an oversight, but we implement it
anyway for consistency with other Python implementations.
2024-04-12 18:29:58 +08:00
David Mak 00d1b9be9b core: Fix __inv__ for i8-based boolean operands 2024-04-12 15:35:54 +08:00
David Mak 8404d4c4dc meta: Update dependencies 2024-04-12 15:29:09 +08:00
David Mak e614dd4257 core/type_inferencer: Fix location of unary/compare expressions
Codegen uses this location information to determine the CallId, and if
a function call is the operand of a unary expression or left-hand
operand of a compare expression, codegen will use the type of the
operator expression rather than the actual operand type.
2024-04-05 15:42:10 +08:00
David Mak 937a8b9698 core/magic_methods: Fix type of unary ops with primitive types 2024-04-05 13:23:08 +08:00
David Mak 876ad6c59c core/type_inferencer: Include location info if inferencer fails 2024-04-05 13:22:35 +08:00
David Mak a920fe0501 core: Implement elementwise comparison operators 2024-04-03 00:07:33 +08:00
David Mak 727a1886b3 core: Implement elementwise unary operators 2024-04-03 00:07:33 +08:00
David Mak 6af13a8261 core: Implement elementwise binary operators
Including immediate variants of these operators.
2024-04-03 00:07:33 +08:00
David Mak 3540d0ab29 core/magic_methods: Add typeof_*op
Used to determine the expected type of the binary operator with
primitive operands.
2024-04-03 00:07:33 +08:00
David Mak 3a6c53d760 core/toplevel/numpy: Split ndarray type var utilities 2024-04-03 00:07:33 +08:00
David Mak 87bc34f7ec core: Implement calculations for broadcasting ndarrays 2024-04-03 00:07:31 +08:00
David Mak f50a5f0345 core/type_inferencer: Allow both int32 and isize when indexing ndarray 2024-04-02 16:49:12 +08:00
David Mak a77fd213e0 core/magic_methods: Allow unknown return types
These types can be later inferred by the type inferencer.
2024-04-02 16:49:12 +08:00
David Mak 8f1497df83 core/helper: Add PrimitiveDefinitionIds::iter 2024-04-02 16:49:12 +08:00
David Mak 5ca2dbeec8 core/typedef: Add Type::obj_id to replace get_obj_id 2024-04-02 16:49:10 +08:00
David Mak 9a98cde595 core: Extract codegen portion of gen_*op_expr
This allows *ops to be generated internally using LLVM values as
input. Required in a future change.
2024-04-01 16:48:25 +08:00
David Mak 5ba8601b39 core: Remove ArrayValue variants of functions
These will be lowered and optimized away later anyways, and we have
ArrayLikeAccessor now.
2024-04-01 16:48:25 +08:00
David Mak 26a01b14d5 core: Use more typed slices in APIs 2024-04-01 16:48:25 +08:00
David Mak d5f4817134 core/builtins: Fix len() on ndarrays 2024-04-01 16:48:24 +08:00
David Mak 789bfb5a26 core: Fix index-based operations not returning i32 2024-04-01 16:46:45 +08:00
David Mak 4bb0e60981 core: Apply clippy suggestions 2024-04-01 16:46:41 +08:00
Sebastien Bourdeauducq 623fcf85af msys2: update 2024-03-25 14:45:36 +08:00
David Mak 13f06f3e29 core: Refactor VarMap to IndexMap
This is the only Map I can find that preserves insertion order while
also deduplicating elements by key.
2024-03-22 15:51:23 +08:00
David Mak f0da9c0283 core: Add ArrayLikeValue
For exposing LLVM values that can be accessed like an array.
2024-03-22 15:51:06 +08:00
David Mak 2c4bf3ce59 core: Allow unsized CodeGenerator to be passed to some codegen functions
Enables codegen_callback to call these codegen functions as well.
2024-03-22 15:07:28 +08:00
David Mak e980f19c93 core: Simplify typed value assertions 2024-03-22 15:07:28 +08:00
David Mak cfbc37c1ed core: Add gen_for_callback_incrementing
Simplifies generation of monotonically increasing for loops.
2024-03-22 15:07:28 +08:00
David Mak 50264e8750 core: Add missing unchecked accessors for NDArrayDimsProxy 2024-03-22 15:07:28 +08:00
David Mak 1b77e62901 core: Split numpy into codegen and toplevel 2024-03-22 15:07:28 +08:00
David Mak fd44ee6887 core: Apply clippy suggestions 2024-03-22 15:07:23 +08:00
David Mak c8866b1534 core/classes: Rename get_* functions to remove prefix
As suggested by Rust API Guidelines.
2024-03-21 15:46:10 +08:00
David Mak 84a888758a core: Rename unsafe functions to unchecked
This is this intended name of the functions.
2024-03-21 15:46:10 +08:00
David Mak 9d550725b7 meta: Update cargo dependencies 2024-03-21 15:45:26 +08:00
David Mak 2edc1de0b6 standalone: Update ndarray.py to output all elements in ndarrays 2024-03-07 14:59:13 +08:00
David Mak c3b122acfc core: Implement `ndarray.copy` 2024-03-07 14:59:13 +08:00
David Mak a94927a11d core: Update __builtin_assume expressions
No dimension size should be 0.
2024-03-07 14:59:13 +08:00
David Mak ebf86cd134 core: Use size_t for accessing array elements 2024-03-07 14:59:13 +08:00
David Mak cccd8f2d00 core: Fix ndarray_eye not preserving signness of offset 2024-03-07 14:59:13 +08:00
David Mak 3292aed099 core: Fix ndarray subscript operator returning the wrong object
Should be returning the newly created object instead of the original
ndarray...
2024-03-07 14:59:13 +08:00
David Mak 96b7f29679 core: Implement `ndarray.fill` 2024-03-07 14:59:13 +08:00
David Mak 3d2abf73c8 core: Replace ndarray_init_dims IRRT impl with IR impl
Implementation of that function in IR allows for more flexibility in
terms of different integer type widths.
2024-03-07 14:59:13 +08:00
David Mak f682e9bf7a core: Match IRRT compile flavor with build profile 2024-03-07 14:59:02 +08:00
David Mak b26cb2b360 core: Express member func def IDs as offsets from class def ID 2024-03-06 12:24:39 +08:00
David Mak 2317516cf6 core: Use tvars from ndarray for class definition 2024-03-04 23:58:02 +08:00
David Mak 77de24ef74 core: Use BTreeMap for type variable mapping
There have been multiple instances where I had the need to iterate over
type variables, only to discover that the traversal order is arbitrary.

This commit fixes that by adding SortedMapping, which utilizes BTreeMap
internally to guarantee a traversal order. All instances of VarMap are
now refactored to use this to ensure that type variables are iterated in
 the order of its variable ID, which should be monotonically incremented
 by the unifier.
2024-03-04 23:56:04 +08:00
David Mak 234a6bde2a core: Use TObj for NDArray 2024-03-01 15:41:55 +08:00
David Mak c3db6297d9 core: Add primitive definition-id list
So that we have a single ground truth for the definition IDs of
primitive types.
2024-03-01 11:29:10 +08:00
David Mak 82fdb02d13 core: Extract LLVM intrinsic functions to their functions 2024-02-23 15:41:06 +08:00
David Mak 4efdd17513 core: Add missing From implementations for LLVM wrapper classes 2024-02-23 15:41:06 +08:00
David Mak 49de81ef1e core: Apply clippy suggestions 2024-02-23 15:41:06 +08:00
David Mak 8492503af2 core: Update cargo dependencies 2024-02-23 15:41:04 +08:00
Sebastien Bourdeauducq e1dbe2526a flake: switch to nixpkgs unstable for newer rustc 2024-02-20 15:46:51 +08:00
Sebastien Bourdeauducq f37de381ce update dependencies 2024-02-20 13:33:20 +08:00
Sebastien Bourdeauducq 4452c8986a update ARTIQ version used for PGO profiling 2024-02-20 13:29:00 +08:00
David Mak 22e831cb76 core: Add test for indexing into ndarray 2024-02-19 17:13:10 +08:00
David Mak cc538d221a core: Implement codegen for indexing into ndarray 2024-02-19 17:13:09 +08:00
David Mak 0d5c53e60c core: Implement type inference for indexing into ndarray 2024-02-19 17:13:09 +08:00
David Mak 976a9512c1 core: Add const variants to NDArray element getters 2024-02-19 16:56:21 +08:00
David Mak 1eacaf9afa core: Fix IRRT argument order to ndarray_flatten_index 2024-02-19 16:37:13 +08:00
David Mak 8c7e44098a core: Fix IRRT implementation of ndarray_flatten_index 2024-02-19 16:37:13 +08:00
David Mak 282a3e1911 core: Fix typo in error message 2024-02-14 16:26:13 +08:00
David Mak 5cecb2bb74 core: Fix Literal use in variable type annotation 2024-02-06 18:16:14 +08:00
David Mak 1963c30744 core: Use Display output for locations 2024-02-06 18:11:51 +08:00
David Mak 27011f385b core: Add location to non-primitive value return error 2024-02-02 12:49:21 +08:00
David Mak d6302b6ec8 core: Allow tuple of primitives to be returned 2024-02-02 12:48:52 +08:00
David Mak fef4b2a5ce standalone: Disable tests requiring return of non-primitive values 2024-01-29 12:49:50 +08:00
David Mak b3736c3e99 core: Disallow returning of non-primitive values
Non-primitive values are represented by an `alloca`-ed value in the
function body, and when the pointer is returned from the function, the
`alloca`-ed object is deallocated on the stack.

Related to #54.
2024-01-29 12:49:24 +08:00
Sebastien Bourdeauducq e328e44c9a update MSYS2 2024-01-26 15:55:45 +08:00
Sebastien Bourdeauducq 9e4e90f8a0 update dependencies 2024-01-26 15:52:48 +08:00
David Mak 8470915809 core: Add NDArrayValue and helper functions 2024-01-25 15:51:39 +08:00
David Mak 148900302e core: Add RangeValue and helper functions 2024-01-25 15:51:39 +08:00
David Mak 5ee08b585f core: Add ListValue and helper functions 2024-01-25 15:51:39 +08:00
David Mak f1581299fc core: Minor changes to IRRT
Add missing documentation, remove redundant lifetime variables, and fix
typos.
2024-01-25 15:50:53 +08:00
David Mak af95ba5012 standalone: Add debug flag to run_demo.sh
Allows running demos using the debug build instead of the (default)
release build.
2024-01-25 15:50:53 +08:00
David Mak 9c9756be33 standalone: Use size_t in demo.c 2024-01-25 15:50:53 +08:00
David Mak 2a922c7480 artiq: Fix source module of NDArray
Should be `numpy.typing` instead of `numpy`.
2024-01-17 10:40:08 +08:00
David Mak e3e2c36ef4 core: Mark TNDArray and TLiteral as unimplemented in tests 2024-01-17 09:58:14 +08:00
David Mak 4f9a0110c4 meta: Update insta snapshots 2024-01-17 09:49:50 +08:00
David Mak 12c0eed0a3 core: Fix compilation of tests 2024-01-17 09:49:49 +08:00
David Mak c679474f5c standalone: Fix redefinition of ndarray consumer functions 2024-01-17 09:38:13 +08:00
Sébastien Bourdeauducq ab3fa05996 demo: use portable format strings 2024-01-10 18:35:35 +08:00
David Mak 140f8f8a08 core: Implement most ndarray-creation functions 2023-12-22 16:29:55 +08:00
David Mak 27fcf8926e core: Implement ndarray constructor and numpy.empty 2023-12-22 16:29:54 +08:00
David Mak afa7d9b100 core: Implement helper for creation of generic ndarray 2023-12-21 15:39:49 +08:00
David Mak c395472094 core: Initial infrastructure for ndarray 2023-12-21 15:39:46 +08:00
David Mak 03870f222d core: Extract special method handling in type inferencer
To prepare for more special handling with methods.
2023-12-21 15:38:26 +08:00
David Mak e435b25756 core: Allow implicit promotions of integral literals
It should not matter, since it is the value of the literal that matters
with respect to the const generic variable.
2023-12-21 15:21:08 +08:00
David Mak bd792904f9 core: Add size_t to primitive store
Used for ndims in ndarray.
2023-12-21 15:20:31 +08:00
David Mak 1c3a823670 core: Do not discard value names for IRRT 2023-12-20 15:16:02 +08:00
David Mak f01d833d48 standalone: Add missing parenthesis 2023-12-20 15:15:47 +08:00
David Mak 9d64e606f4 core: Reject multiple literal bounds
This is currently broken due to how we handle function calls in the
unifier.
2023-12-18 10:04:25 +08:00
David Mak 6dccb343bb Revert "core: Do not keep unification result for function arguments"
This reverts commit f09f3c27a5.
2023-12-18 10:01:23 +08:00
Sebastien Bourdeauducq d47534e2ad interpret_demo: add typing.Literal 2023-12-18 08:50:49 +08:00
David Mak 8886964776 core: Remove redundant argument in type annotation parsing 2023-12-16 18:40:48 +08:00
David Mak f09f3c27a5 core: Do not keep unification result for function arguments
For some reason, when unifying a function call parameter with an
argument, subsequent calls to the same function will only accept the
type of the substituted argument.

This affect snippets like:

```
def make1() -> C[Literal[1]]:
    return ...

def make2() -> C[Literal[2]]:
    return ...

def consume(instance: C[Literal[1, 2]]):
    pass

consume(make1())
consume(make2())
```

The last statement will result in a compiler error, as the parameter of
consume is replaced with C[Literal[1]].

We fix this by getting a snapshot before performing unification, and
restoring the snapshot after unification succeeds.
2023-12-16 18:40:48 +08:00
David Mak 0bbc9ce6f5 core: Deduplicate values in `Literal`
Matches the behavior with `typing.Literal`.
2023-12-16 18:40:48 +08:00
David Mak 457d3b6cd7 core: Refactor generic constants to `Literal`
Better matches the syntax of `typing.Literal`.
2023-12-16 18:40:48 +08:00
David Mak 5f692debd8 core: Add PrimitiveStore into Unifier
This will be used during unification between a const generic variable
and a `Literal`.
2023-12-16 18:40:48 +08:00
David Mak c7735d935b standalone: Output id of undefined identifier 2023-12-16 18:40:48 +08:00
David Mak b47ac1b89b core: Minor formatting cleanup 2023-12-15 17:46:44 +08:00
98 changed files with 22447 additions and 7718 deletions

1
.clippy.toml Normal file
View File

@ -0,0 +1 @@
doc-valid-idents = ["CPython", "NumPy", ".."]

24
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,24 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_stages: [commit]
repos:
- repo: local
hooks:
- id: nac3-cargo-fmt
name: nac3 cargo format
entry: cargo
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo fmt on the codebase.
args: [fmt]
- id: nac3-cargo-clippy
name: nac3 cargo clippy
entry: cargo
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo clippy on the codebase.
args: [clippy, --tests]

748
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -51,3 +51,12 @@ Use ``nix develop`` in this repository to enter a development shell.
If you are using a different shell than bash you can use e.g. ``nix develop --command fish``. If you are using a different shell than bash you can use e.g. ``nix develop --command fish``.
Build NAC3 with ``cargo build --release``. See the demonstrations in ``nac3artiq`` and ``nac3standalone``. Build NAC3 with ``cargo build --release``. See the demonstrations in ``nac3artiq`` and ``nac3standalone``.
### Pre-Commit Hooks
You are strongly recommended to use the provided pre-commit hooks to automatically reformat files and check for non-optimal Rust practices using Clippy. Run `pre-commit install` to install the hook and `pre-commit` will automatically run `cargo fmt` and `cargo clippy` for you.
Several things to note:
- If `cargo fmt` or `cargo clippy` returns an error, the pre-commit hook will fail. You should fix all errors before trying to commit again.
- If `cargo fmt` reformats some files, the pre-commit hook will also fail. You should review the changes and, if satisfied, try to commit again.

View File

@ -2,16 +2,16 @@
"nodes": { "nodes": {
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1701389149, "lastModified": 1720418205,
"narHash": "sha256-rU1suTIEd5DGCaAXKW6yHoCfR1mnYjOXQFOaH7M23js=", "narHash": "sha256-cPJoFPXU44GlhWg4pUk9oUPqurPlCFZ11ZQPk21GTPU=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "5de0b32be6e85dc1a9404c75131316e4ffbc634c", "rev": "655a58a72a6601292512670343087c2d75d859c1",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "NixOS", "owner": "NixOS",
"ref": "nixos-23.11", "ref": "nixos-unstable",
"repo": "nixpkgs", "repo": "nixpkgs",
"type": "github" "type": "github"
} }

View File

@ -1,7 +1,7 @@
{ {
description = "The third-generation ARTIQ compiler"; description = "The third-generation ARTIQ compiler";
inputs.nixpkgs.url = github:NixOS/nixpkgs/nixos-23.11; inputs.nixpkgs.url = github:NixOS/nixpkgs/nixos-unstable;
outputs = { self, nixpkgs }: outputs = { self, nixpkgs }:
let let
@ -13,6 +13,7 @@
'' ''
mkdir -p $out/bin mkdir -p $out/bin
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
ln -s ${pkgs.llvmPackages_14.clang}/bin/clang $out/bin/clang-irrt-test
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
''; '';
nac3artiq = pkgs.python3Packages.toPythonModule ( nac3artiq = pkgs.python3Packages.toPythonModule (
@ -23,16 +24,19 @@
cargoLock = { cargoLock = {
lockFile = ./Cargo.lock; lockFile = ./Cargo.lock;
}; };
cargoTestFlags = [ "--features" "test" ];
passthru.cargoLock = cargoLock; passthru.cargoLock = cargoLock;
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_14.clang llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ]; nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_14.clang llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
buildInputs = [ pkgs.python3 llvm-nac3 ]; buildInputs = [ pkgs.python3 llvm-nac3 ];
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ]; checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
checkPhase = checkPhase =
'' ''
echo "Running Cargo tests..." echo "Checking nac3standalone demos..."
pushd nac3standalone/demo pushd nac3standalone/demo
patchShebangs . patchShebangs .
./check_demos.sh
popd popd
echo "Running Cargo tests..."
cargoCheckHook cargoCheckHook
''; '';
installPhase = installPhase =
@ -92,8 +96,8 @@
(pkgs.fetchFromGitHub { (pkgs.fetchFromGitHub {
owner = "m-labs"; owner = "m-labs";
repo = "artiq"; repo = "artiq";
rev = "8b4572f9cad34ac0c2b6f6bba9382e7b59b2f93b"; rev = "923ca3377d42c815f979983134ec549dc39d3ca0";
sha256 = "sha256-O/0sUSxxXU1AL9cmT9qdzCkzdOKREBNftz22/8ouQcc="; sha256 = "sha256-oJoEeNEeNFSUyh6jXG8Tzp6qHVikeHS0CzfE+mODPgw=";
}) })
]; ];
buildInputs = [ buildInputs = [
@ -147,7 +151,7 @@
buildInputs = with pkgs; [ buildInputs = with pkgs; [
# build dependencies # build dependencies
packages.x86_64-linux.llvm-nac3 packages.x86_64-linux.llvm-nac3
llvmPackages_14.clang # demo llvmPackages_14.clang llvmPackages_14.llvm.out # for running nac3standalone demos
packages.x86_64-linux.llvm-tools-irrt packages.x86_64-linux.llvm-tools-irrt
cargo cargo
rustc rustc
@ -157,8 +161,12 @@
# development tools # development tools
cargo-insta cargo-insta
clippy clippy
pre-commit
rustfmt rustfmt
rust-analyzer
]; ];
# https://nixos.wiki/wiki/Rust#Shell.nix_example
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -9,16 +9,16 @@ name = "nac3artiq"
crate-type = ["cdylib"] crate-type = ["cdylib"]
[dependencies] [dependencies]
itertools = "0.12" itertools = "0.13"
pyo3 = { version = "0.20", features = ["extension-module"] } pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
parking_lot = "0.12" parking_lot = "0.12"
tempfile = "3.8" tempfile = "3.10"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
nac3ld = { path = "../nac3ld" } nac3ld = { path = "../nac3ld" }
[dependencies.inkwell] [dependencies.inkwell]
version = "0.2" version = "0.4"
default-features = false default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"] features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]

View File

@ -0,0 +1,24 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
core: KernelInvariant[Core]
attr1: KernelInvariant[str]
attr2: KernelInvariant[int32]
def __init__(self):
self.core = Core()
self.attr2 = 32
self.attr1 = "SAMPLE"
@kernel
def run(self):
print_int32(self.attr2)
self.attr1
if __name__ == "__main__":
Demo().run()

View File

@ -0,0 +1,40 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
attr1: KernelInvariant[int32] = 2
attr2: int32 = 4
attr3: Kernel[int32]
@kernel
def __init__(self):
self.attr3 = 8
@nac3
class NAC3Devices:
core: KernelInvariant[Core]
attr4: KernelInvariant[int32] = 16
def __init__(self):
self.core = Core()
@kernel
def run(self):
Demo.attr1 # Supported
# Demo.attr2 # Field not accessible on Kernel
# Demo.attr3 # Only attributes can be accessed in this way
# Demo.attr1 = 2 # Attributes are immutable
self.attr4 # Attributes can be accessed within class
obj = Demo()
obj.attr1 # Attributes can be accessed by class objects
NAC3Devices.attr4 # Attributes accessible for classes without __init__
if __name__ == "__main__":
NAC3Devices().run()

View File

@ -1,12 +1,13 @@
use nac3core::{ use nac3core::{
codegen::{ codegen::{
expr::gen_call, expr::gen_call,
llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
stmt::{gen_block, gen_with}, stmt::{gen_block, gen_with},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{DefinitionId, GenCall}, toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall},
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum} typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
}; };
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
@ -15,7 +16,10 @@ use inkwell::{
context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace, context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace,
}; };
use pyo3::{PyObject, PyResult, Python, types::{PyDict, PyList}}; use pyo3::{
types::{PyDict, PyList},
PyObject, PyResult, Python,
};
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
@ -41,7 +45,7 @@ enum ParallelMode {
/// ///
/// Each function call within the `with` block (except those within a nested `sequential` block) /// Each function call within the `with` block (except those within a nested `sequential` block)
/// are treated to be executed in parallel. /// are treated to be executed in parallel.
Deep Deep,
} }
pub struct ArtiqCodeGenerator<'a> { pub struct ArtiqCodeGenerator<'a> {
@ -60,7 +64,7 @@ pub struct ArtiqCodeGenerator<'a> {
end: Option<Expr<Option<Type>>>, end: Option<Expr<Option<Type>>>,
timeline: &'a (dyn TimeFns + Sync), timeline: &'a (dyn TimeFns + Sync),
/// The [ParallelMode] of the current parallel context. /// The [`ParallelMode`] of the current parallel context.
/// ///
/// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel` /// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel`
/// statement, which is used to determine when and how the timeline should be updated. /// statement, which is used to determine when and how the timeline should be updated.
@ -91,14 +95,13 @@ impl<'a> ArtiqCodeGenerator<'a> {
/// ///
/// Direct-`parallel` block context refers to when the generator is generating statements whose /// Direct-`parallel` block context refers to when the generator is generating statements whose
/// closest parent `with` statement is a `with parallel` block. /// closest parent `with` statement is a `with parallel` block.
fn timeline_reset_start( fn timeline_reset_start(&mut self, ctx: &mut CodeGenContext<'_, '_>) -> Result<(), String> {
&mut self,
ctx: &mut CodeGenContext<'_, '_>
) -> Result<(), String> {
if let Some(start) = self.start.clone() { if let Some(start) = self.start.clone() {
let start_val = self.gen_expr(ctx, &start)? let start_val = self.gen_expr(ctx, &start)?.unwrap().to_basic_value_enum(
.unwrap() ctx,
.to_basic_value_enum(ctx, self, start.custom.unwrap())?; self,
start.custom.unwrap(),
)?;
self.timeline.emit_at_mu(ctx, start_val); self.timeline.emit_at_mu(ctx, start_val);
} }
@ -124,30 +127,22 @@ impl<'a> ArtiqCodeGenerator<'a> {
store_name: Option<&str>, store_name: Option<&str>,
) -> Result<(), String> { ) -> Result<(), String> {
if let Some(end) = end { if let Some(end) = end {
let old_end = self.gen_expr(ctx, &end)? let old_end = self.gen_expr(ctx, &end)?.unwrap().to_basic_value_enum(
.unwrap() ctx,
.to_basic_value_enum(ctx, self, end.custom.unwrap())?; self,
end.custom.unwrap(),
)?;
let now = self.timeline.emit_now_mu(ctx); let now = self.timeline.emit_now_mu(ctx);
let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| { let max =
let i64 = ctx.ctx.i64_type(); call_int_smax(ctx, old_end.into_int_value(), now.into_int_value(), Some("smax"));
ctx.module.add_function( let end_store = self
"llvm.smax.i64", .gen_store_target(
i64.fn_type(&[i64.into(), i64.into()], false),
None,
)
});
let max = ctx
.builder
.build_call(smax, &[old_end.into(), now.into()], "smax")
.try_as_basic_value()
.left()
.unwrap();
let end_store = self.gen_store_target(
ctx, ctx,
&end, &end,
store_name.map(|name| format!("{name}.addr")).as_deref())? store_name.map(|name| format!("{name}.addr")).as_deref(),
)?
.unwrap(); .unwrap();
ctx.builder.build_store(end_store, max); ctx.builder.build_store(end_store, max).unwrap();
} }
Ok(()) Ok(())
@ -170,8 +165,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
fn gen_block<'ctx, 'a, 'c, I: Iterator<Item = &'c Stmt<Option<Type>>>>( fn gen_block<'ctx, 'a, 'c, I: Iterator<Item = &'c Stmt<Option<Type>>>>(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
stmts: I stmts: I,
) -> Result<(), String> where Self: Sized { ) -> Result<(), String>
where
Self: Sized,
{
// Legacy parallel emits timeline end-update/timeline-reset after each top-level statement // Legacy parallel emits timeline end-update/timeline-reset after each top-level statement
// in the parallel block // in the parallel block
if self.parallel_mode == ParallelMode::Legacy { if self.parallel_mode == ParallelMode::Legacy {
@ -215,9 +213,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::With { items, body, .. } = &stmt.node else { let StmtKind::With { items, body, .. } = &stmt.node else { unreachable!() };
unreachable!()
};
if items.len() == 1 && items[0].optional_vars.is_none() { if items.len() == 1 && items[0].optional_vars.is_none() {
let item = &items[0]; let item = &items[0];
@ -242,9 +238,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let old_parallel_mode = self.parallel_mode; let old_parallel_mode = self.parallel_mode;
let now = if let Some(old_start) = &old_start { let now = if let Some(old_start) = &old_start {
self.gen_expr(ctx, old_start)? self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum(
.unwrap() ctx,
.to_basic_value_enum(ctx, self, old_start.custom.unwrap())? self,
old_start.custom.unwrap(),
)?
} else { } else {
self.timeline.emit_now_mu(ctx) self.timeline.emit_now_mu(ctx)
}; };
@ -262,13 +260,13 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let start_expr = Located { let start_expr = Located {
// location does not matter at this point // location does not matter at this point
location: stmt.location, location: stmt.location,
node: ExprKind::Name { id: start, ctx: name_ctx.clone() }, node: ExprKind::Name { id: start, ctx: *name_ctx },
custom: Some(ctx.primitives.int64), custom: Some(ctx.primitives.int64),
}; };
let start = self let start = self
.gen_store_target(ctx, &start_expr, Some("start.addr"))? .gen_store_target(ctx, &start_expr, Some("start.addr"))?
.unwrap(); .unwrap();
ctx.builder.build_store(start, now); ctx.builder.build_store(start, now).unwrap();
Ok(Some(start_expr)) as Result<_, String> Ok(Some(start_expr)) as Result<_, String>
}, },
|v| Ok(Some(v)), |v| Ok(Some(v)),
@ -277,13 +275,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let end_expr = Located { let end_expr = Located {
// location does not matter at this point // location does not matter at this point
location: stmt.location, location: stmt.location,
node: ExprKind::Name { id: end, ctx: name_ctx.clone() }, node: ExprKind::Name { id: end, ctx: *name_ctx },
custom: Some(ctx.primitives.int64), custom: Some(ctx.primitives.int64),
}; };
let end = self let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap();
.gen_store_target(ctx, &end_expr, Some("end.addr"))? ctx.builder.build_store(end, now).unwrap();
.unwrap();
ctx.builder.build_store(end, now);
self.end = Some(end_expr); self.end = Some(end_expr);
self.name_counter += 1; self.name_counter += 1;
self.parallel_mode = match id.to_string().as_str() { self.parallel_mode = match id.to_string().as_str() {
@ -312,10 +308,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
// set duration // set duration
let end_expr = self.end.take().unwrap(); let end_expr = self.end.take().unwrap();
let end_val = self let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum(
.gen_expr(ctx, &end_expr)? ctx,
.unwrap() self,
.to_basic_value_enum(ctx, self, end_expr.custom.unwrap())?; end_expr.custom.unwrap(),
)?;
// inside a sequential block // inside a sequential block
if old_start.is_none() { if old_start.is_none() {
@ -396,9 +393,32 @@ fn gen_rpc_tag(
gen_rpc_tag(ctx, *ty, buffer)?; gen_rpc_tag(ctx, *ty, buffer)?;
} }
} }
TList { ty } => { TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let ty = iter_type_vars(params).next().unwrap().ty;
buffer.push(b'l'); buffer.push(b'l');
gen_rpc_tag(ctx, *ty, buffer)?; gen_rpc_tag(ctx, ty, buffer)?;
}
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let ndarray_ndims = if let TLiteral { values, .. } =
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
{
if values.len() != 1 {
return Err(format!("NDArray types with multiple literal bounds for ndims is not supported: {}", ctx.unifier.stringify(ty)));
}
let value = values[0].clone();
u64::try_from(value.clone())
.map_err(|()| format!("Expected u64 for ndarray.ndims, got {value}"))?
} else {
unreachable!()
};
assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims));
buffer.push(b'a');
buffer.push((ndarray_ndims & 0xFF) as u8);
gen_rpc_tag(ctx, ndarray_dtype, buffer)?;
} }
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
} }
@ -445,7 +465,7 @@ fn rpc_codegen_callback_fn<'ctx>(
format!("tagptr{}", fun.1 .0).as_str(), format!("tagptr{}", fun.1 .0).as_str(),
); );
tag_arr_ptr.set_initializer(&int8.const_array( tag_arr_ptr.set_initializer(&int8.const_array(
&tag.iter().map(|v| int8.const_int(*v as u64, false)).collect::<Vec<_>>(), &tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(),
)); ));
tag_arr_ptr.set_linkage(Linkage::Private); tag_arr_ptr.set_linkage(Linkage::Private);
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash); let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
@ -463,23 +483,15 @@ fn rpc_codegen_callback_fn<'ctx>(
let arg_length = args.len() + usize::from(obj.is_some()); let arg_length = args.len() + usize::from(obj.is_some());
let stacksave = ctx.module.get_function("llvm.stacksave").unwrap_or_else(|| { let stackptr = call_stacksave(ctx, Some("rpc.stack"));
ctx.module.add_function("llvm.stacksave", ptr_type.fn_type(&[], false), None) let args_ptr = ctx
}); .builder
let stackrestore = ctx.module.get_function("llvm.stackrestore").unwrap_or_else(|| { .build_array_alloca(
ctx.module.add_function(
"llvm.stackrestore",
ctx.ctx.void_type().fn_type(&[ptr_type.into()], false),
None,
)
});
let stackptr = ctx.builder.build_call(stacksave, &[], "rpc.stack");
let args_ptr = ctx.builder.build_array_alloca(
ptr_type, ptr_type,
ctx.ctx.i32_type().const_int(arg_length as u64, false), ctx.ctx.i32_type().const_int(arg_length as u64, false),
"argptr", "argptr",
); )
.unwrap();
// -- rpc args handling // -- rpc args handling
let mut keys = fun.0.args.clone(); let mut keys = fun.0.args.clone();
@ -489,10 +501,8 @@ fn rpc_codegen_callback_fn<'ctx>(
} }
// default value handling // default value handling
for k in keys { for k in keys {
mapping.insert( mapping
k.name, .insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into()
);
} }
// reorder the parameters // reorder the parameters
let mut real_params = fun let mut real_params = fun
@ -511,17 +521,19 @@ fn rpc_codegen_callback_fn<'ctx>(
} }
for (i, arg) in real_params.iter().enumerate() { for (i, arg) in real_params.iter().enumerate() {
let arg_slot = generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap(); let arg_slot =
ctx.builder.build_store(arg_slot, *arg); generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg"); ctx.builder.build_store(arg_slot, *arg).unwrap();
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap();
let arg_ptr = unsafe { let arg_ptr = unsafe {
ctx.builder.build_gep( ctx.builder.build_gep(
args_ptr, args_ptr,
&[int32.const_int(i as u64, false)], &[int32.const_int(i as u64, false)],
&format!("rpc.arg{i}"), &format!("rpc.arg{i}"),
) )
}; }
ctx.builder.build_store(arg_ptr, arg_slot); .unwrap();
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
} }
// call // call
@ -539,18 +551,12 @@ fn rpc_codegen_callback_fn<'ctx>(
None, None,
) )
}); });
ctx.builder.build_call( ctx.builder
rpc_send, .build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
&[service_id.into(), tag_ptr.into(), args_ptr.into()], .unwrap();
"rpc.send",
);
// reclaim stack space used by arguments // reclaim stack space used by arguments
ctx.builder.build_call( call_stackrestore(ctx, stackptr);
stackrestore,
&[stackptr.try_as_basic_value().unwrap_left().into()],
"rpc.stackrestore",
);
// -- receive value: // -- receive value:
// T result = { // T result = {
@ -578,41 +584,35 @@ fn rpc_codegen_callback_fn<'ctx>(
let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret); let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret);
let need_load = !ret_ty.is_pointer_type(); let need_load = !ret_ty.is_pointer_type();
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot"); let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot").unwrap();
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr"); let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr").unwrap();
ctx.builder.build_unconditional_branch(head_bb); ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(head_bb); ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr"); let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr").unwrap();
phi.add_incoming(&[(&slotgen, prehead_bb)]); phi.add_incoming(&[(&slotgen, prehead_bb)]);
let alloc_size = ctx let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap() .unwrap()
.into_int_value(); .into_int_value();
let is_done = ctx.builder.build_int_compare( let is_done = ctx
inkwell::IntPredicate::EQ, .builder
int32.const_zero(), .build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done")
alloc_size, .unwrap();
"rpc.done",
);
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb); ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb); ctx.builder.position_at_end(alloc_bb);
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc"); let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc").unwrap();
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr"); let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]); phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb); ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb); ctx.builder.position_at_end(tail_bb);
let result = ctx.builder.build_load(slot, "rpc.result"); let result = ctx.builder.build_load(slot, "rpc.result").unwrap();
if need_load { if need_load {
ctx.builder.build_call( call_stackrestore(ctx, stackptr);
stackrestore,
&[stackptr.try_as_basic_value().unwrap_left().into()],
"rpc.stackrestore",
);
} }
Ok(Some(result)) Ok(Some(result))
} }
@ -633,14 +633,20 @@ pub fn attributes_writeback(
let mut scratch_buffer = Vec::new(); let mut scratch_buffer = Vec::new();
for val in (*globals).values() { for val in (*globals).values() {
let val = val.as_ref(py); let val = val.as_ref(py);
let ty = inner_resolver.get_obj_type(py, val, &mut ctx.unifier, &top_levels, &ctx.primitives)?; let ty = inner_resolver.get_obj_type(
py,
val,
&mut ctx.unifier,
&top_levels,
&ctx.primitives,
)?;
if let Err(ty) = ty { if let Err(ty) = ty {
return Ok(Err(ty)) return Ok(Err(ty));
} }
let ty = ty.unwrap(); let ty = ty.unwrap();
match &*ctx.unifier.get_ty(ty) { match &*ctx.unifier.get_ty(ty) {
TypeEnum::TObj { fields, obj_id, .. } TypeEnum::TObj { fields, obj_id, .. }
if *obj_id != ctx.primitives.option.get_obj_id(&ctx.unifier) => if *obj_id != ctx.primitives.option.obj_id(&ctx.unifier).unwrap() =>
{ {
// we only care about primitive attributes // we only care about primitive attributes
// for non-primitive attributes, they should be in another global // for non-primitive attributes, they should be in another global
@ -648,14 +654,19 @@ pub fn attributes_writeback(
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
for (name, (field_ty, is_mutable)) in fields { for (name, (field_ty, is_mutable)) in fields {
if !is_mutable { if !is_mutable {
continue continue;
} }
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
attributes.push(name.to_string()); attributes.push(name.to_string());
let index = ctx.get_attr_index(ty, *name); let (index, _) = ctx.get_attr_index(ty, *name);
values.push((*field_ty, ctx.build_gep_and_load( values.push((
*field_ty,
ctx.build_gep_and_load(
obj.into_pointer_value(), obj.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)], None))); &[zero, int32.const_int(index as u64, false)],
None,
),
));
} }
} }
if !attributes.is_empty() { if !attributes.is_empty() {
@ -664,33 +675,46 @@ pub fn attributes_writeback(
pydict.set_item("fields", attributes)?; pydict.set_item("fields", attributes)?;
host_attributes.append(pydict)?; host_attributes.append(pydict)?;
} }
}, }
TypeEnum::TList { ty: elem_ty } => { TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() { let elem_ty = iter_type_vars(params).next().unwrap().ty;
if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() {
let pydict = PyDict::new(py); let pydict = PyDict::new(py);
pydict.set_item("obj", val)?; pydict.set_item("obj", val)?;
host_attributes.append(pydict)?; host_attributes.append(pydict)?;
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap())); values.push((
ty,
inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(),
));
}
} }
},
_ => {} _ => {}
} }
} }
let fun = FunSignature { let fun = FunSignature {
args: values.iter().enumerate().map(|(i, (ty, _))| FuncArg { args: values
.iter()
.enumerate()
.map(|(i, (ty, _))| FuncArg {
name: i.to_string().into(), name: i.to_string().into(),
ty: *ty, ty: *ty,
default_value: None default_value: None,
}).collect(), })
.collect(),
ret: ctx.primitives.none, ret: ctx.primitives.none,
vars: HashMap::default() vars: VarMap::default(),
}; };
let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect(); let args: Vec<_> =
if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, DefinitionId(0)), args, generator) { values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
if let Err(e) =
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator)
{
return Ok(Err(e)); return Ok(Err(e));
} }
Ok(Ok(())) Ok(Ok(()))
}).unwrap()?; })
.unwrap()?;
Ok(()) Ok(())
} }

View File

@ -1,3 +1,21 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
unsafe_op_in_unsafe_fn,
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::enum_glob_use,
clippy::similar_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::fs; use std::fs;
use std::io::Write; use std::io::Write;
@ -14,16 +32,16 @@ use inkwell::{
OptimizationLevel, OptimizationLevel,
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3core::codegen::{CodeGenLLVMOptions, CodeGenTargetMachineOptions, gen_func_impl}; use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions};
use nac3core::toplevel::builtins::get_exn_constructor; use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{TypeEnum, Unifier}; use nac3core::typecheck::typedef::{TypeEnum, Unifier, VarMap};
use nac3parser::{ use nac3parser::{
ast::{ExprKind, Stmt, StmtKind, StrRef}, ast::{ExprKind, Stmt, StmtKind, StrRef},
parser::parse_program, parser::parse_program,
}; };
use pyo3::create_exception;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet}; use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use pyo3::create_exception;
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
@ -46,7 +64,7 @@ use tempfile::{self, TempDir};
use crate::codegen::attributes_writeback; use crate::codegen::attributes_writeback;
use crate::{ use crate::{
codegen::{rpc_codegen_callback, ArtiqCodeGenerator}, codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
symbol_resolver::{InnerResolver, PythonHelper, Resolver, DeferredEvaluationStore}, symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver},
}; };
mod codegen; mod codegen;
@ -63,6 +81,17 @@ enum Isa {
CortexA9, CortexA9,
} }
impl Isa {
/// Returns the number of bits in `size_t` for the [`Isa`].
fn get_size_type(self) -> u32 {
if self == Isa::Host {
64u32
} else {
32u32
}
}
}
#[derive(Clone)] #[derive(Clone)]
pub struct PrimitivePythonId { pub struct PrimitivePythonId {
int: u64, int: u64,
@ -73,7 +102,11 @@ pub struct PrimitivePythonId {
float: u64, float: u64,
float64: u64, float64: u64,
bool: u64, bool: u64,
np_bool_: u64,
string: u64,
np_str_: u64,
list: u64, list: u64,
ndarray: u64,
tuple: u64, tuple: u64,
typevar: u64, typevar: u64,
const_generic_marker: u64, const_generic_marker: u64,
@ -126,9 +159,7 @@ impl Nac3 {
for mut stmt in parser_result { for mut stmt in parser_result {
let include = match stmt.node { let include = match stmt.node {
StmtKind::ClassDef { StmtKind::ClassDef { ref decorator_list, ref mut body, ref mut bases, .. } => {
ref decorator_list, ref mut body, ref mut bases, ..
} => {
let nac3_class = decorator_list.iter().any(|decorator| { let nac3_class = decorator_list.iter().any(|decorator| {
if let ExprKind::Name { id, .. } = decorator.node { if let ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "nac3" id.to_string() == "nac3"
@ -148,7 +179,8 @@ impl Nac3 {
if *id == "Exception".into() { if *id == "Exception".into() {
Ok(true) Ok(true)
} else { } else {
let base_obj = module.getattr(py, id.to_string().as_str())?; let base_obj =
module.getattr(py, id.to_string().as_str())?;
let base_id = id_fn.call1((base_obj,))?.extract()?; let base_id = id_fn.call1((base_obj,))?.extract()?;
Ok(registered_class_ids.contains(&base_id)) Ok(registered_class_ids.contains(&base_id))
} }
@ -277,9 +309,11 @@ impl Nac3 {
py: Python, py: Python,
link_fn: &dyn Fn(&Module) -> PyResult<T>, link_fn: &dyn Fn(&Module) -> PyResult<T>,
) -> PyResult<T> { ) -> PyResult<T> {
let size_t = self.isa.get_size_type();
let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new( let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
self.builtins.clone(), self.builtins.clone(),
ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" }, ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" },
size_t,
); );
let builtins = PyModule::import(py, "builtins")?; let builtins = PyModule::import(py, "builtins")?;
@ -327,8 +361,9 @@ impl Nac3 {
let class_obj; let class_obj;
if let StmtKind::ClassDef { name, .. } = &stmt.node { if let StmtKind::ClassDef { name, .. } = &stmt.node {
let class = py_module.getattr(name.to_string().as_str()).unwrap(); let class = py_module.getattr(name.to_string().as_str()).unwrap();
if issubclass.call1((class, exn_class)).unwrap().extract().unwrap() && if issubclass.call1((class, exn_class)).unwrap().extract().unwrap()
class.getattr("artiq_builtin").is_err() { && class.getattr("artiq_builtin").is_err()
{
class_obj = Some(class); class_obj = Some(class);
} else { } else {
class_obj = None; class_obj = None;
@ -374,12 +409,12 @@ impl Nac3 {
let (name, def_id, ty) = composer let (name, def_id, ty) = composer
.register_top_level(stmt.clone(), Some(resolver.clone()), path, false) .register_top_level(stmt.clone(), Some(resolver.clone()), path, false)
.map_err(|e| { .map_err(|e| {
CompileError::new_err(format!( CompileError::new_err(format!("compilation failed\n----------\n{e}"))
"compilation failed\n----------\n{e}"
))
})?; })?;
if let Some(class_obj) = class_obj { if let Some(class_obj) = class_obj {
self.exception_ids.write().insert(def_id.0, store_obj.call1(py, (class_obj, ))?.extract(py)?); self.exception_ids
.write()
.insert(def_id.0, store_obj.call1(py, (class_obj,))?.extract(py)?);
} }
match &stmt.node { match &stmt.node {
@ -456,17 +491,22 @@ impl Nac3 {
exception_ids: self.exception_ids.clone(), exception_ids: self.exception_ids.clone(),
deferred_eval_store: self.deferred_eval_store.clone(), deferred_eval_store: self.deferred_eval_store.clone(),
}); });
let resolver = Arc::new(Resolver(inner_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>; let resolver =
Arc::new(Resolver(inner_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
let (_, def_id, _) = composer let (_, def_id, _) = composer
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false) .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
.unwrap(); .unwrap();
let fun_signature = let fun_signature =
FunSignature { args: vec![], ret: self.primitive.none, vars: HashMap::new() }; FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new(); let mut cache = HashMap::new();
let signature = let signature = store.from_signature(
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache); &mut composer.unifier,
&self.primitive,
&fun_signature,
&mut cache,
);
let signature = store.add_cty(signature); let signature = store.add_cty(signature);
if let Err(e) = composer.start_analysis(true) { if let Err(e) = composer.start_analysis(true) {
@ -485,13 +525,11 @@ impl Nac3 {
msg.unwrap_or(e.iter().sorted().join("\n----------\n")) msg.unwrap_or(e.iter().sorted().join("\n----------\n"))
))) )))
} else { } else {
Err(CompileError::new_err( Err(CompileError::new_err(format!(
format!(
"compilation failed\n----------\n{}", "compilation failed\n----------\n{}",
e.iter().sorted().join("\n----------\n"), e.iter().sorted().join("\n----------\n"),
), )))
)) };
}
} }
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
@ -519,7 +557,9 @@ impl Nac3 {
py, py,
( (
id.0.into_py(py), id.0.into_py(py),
class_def.getattr(py, name.to_string().as_str()).unwrap(), class_def
.getattr(py, name.to_string().as_str())
.unwrap(),
), ),
) )
.unwrap(); .unwrap();
@ -534,7 +574,8 @@ impl Nac3 {
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
let mut definition = defs[def_id.0].write(); let mut definition = defs[def_id.0].write();
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } =
&mut *definition else { &mut *definition
else {
unreachable!() unreachable!()
}; };
@ -556,8 +597,12 @@ impl Nac3 {
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new(); let mut cache = HashMap::new();
let signature = let signature = store.from_signature(
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache); &mut composer.unifier,
&self.primitive,
&fun_signature,
&mut cache,
);
let signature = store.add_cty(signature); let signature = store.add_cty(signature);
let attributes_writeback_task = CodeGenTask { let attributes_writeback_task = CodeGenTask {
subst: Vec::default(), subst: Vec::default(),
@ -590,23 +635,28 @@ impl Nac3 {
let membuffer = membuffers.clone(); let membuffer = membuffers.clone();
py.allow_threads(|| { py.allow_threads(|| {
let (registry, handles) = WorkerRegistry::create_workers( let (registry, handles) =
threads, WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
top_level.clone(),
&self.llvm_options,
&f
);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
let mut generator = ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns); let mut generator =
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
let context = inkwell::context::Context::create(); let context = inkwell::context::Context::create();
let module = context.create_module("attributes_writeback"); let module = context.create_module("attributes_writeback");
let builder = context.create_builder(); let builder = context.create_builder();
let (_, module, _) = gen_func_impl(&context, &mut generator, &registry, builder, module, let (_, module, _) = gen_func_impl(
attributes_writeback_task, |generator, ctx| { &context,
&mut generator,
&registry,
builder,
module,
attributes_writeback_task,
|generator, ctx| {
attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes) attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes)
}).unwrap(); },
)
.unwrap();
let buffer = module.write_bitcode_to_memory(); let buffer = module.write_bitcode_to_memory();
let buffer = buffer.as_slice().into(); let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
@ -622,13 +672,24 @@ impl Nac3 {
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main")) .create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
.unwrap(); .unwrap();
main.link_in_module(other) main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
.map_err(|err| CompileError::new_err(err.to_string()))?;
} }
let builder = context.create_builder(); let builder = context.create_builder();
let modinit_return = main.get_function("__modinit__").unwrap().get_last_basic_block().unwrap().get_terminator().unwrap(); let modinit_return = main
.get_function("__modinit__")
.unwrap()
.get_last_basic_block()
.unwrap()
.get_terminator()
.unwrap();
builder.position_before(&modinit_return); builder.position_before(&modinit_return);
builder.build_call(main.get_function("attributes_writeback").unwrap(), &[], "attributes_writeback"); builder
.build_call(
main.get_function("attributes_writeback").unwrap(),
&[],
"attributes_writeback",
)
.unwrap();
main.link_in_module(load_irrt(&context)) main.link_in_module(load_irrt(&context))
.map_err(|err| CompileError::new_err(err.to_string()))?; .map_err(|err| CompileError::new_err(err.to_string()))?;
@ -642,10 +703,7 @@ impl Nac3 {
} }
// Demote all global variables that will not be referenced in the kernel to private // Demote all global variables that will not be referenced in the kernel to private
let preserved_symbols: Vec<&'static [u8]> = vec![ let preserved_symbols: Vec<&'static [u8]> = vec![b"typeinfo", b"now"];
b"typeinfo",
b"now",
];
let mut global_option = main.get_first_global(); let mut global_option = main.get_first_global();
while let Some(global) = global_option { while let Some(global) = global_option {
if !preserved_symbols.contains(&(global.get_name().to_bytes())) { if !preserved_symbols.contains(&(global.get_name().to_bytes())) {
@ -654,7 +712,9 @@ impl Nac3 {
global_option = global.get_next_global(); global_option = global.get_next_global();
} }
let target_machine = self.llvm_options.target let target_machine = self
.llvm_options
.target
.create_target_machine(self.llvm_options.opt_level) .create_target_machine(self.llvm_options.opt_level)
.expect("couldn't create target machine"); .expect("couldn't create target machine");
@ -718,10 +778,7 @@ impl Nac3 {
} }
} }
fn link_with_lld( fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
elf_filename: String,
obj_filename: String,
) -> PyResult<()>{
let linker_args = vec![ let linker_args = vec![
"-shared".to_string(), "-shared".to_string(),
"--eh-frame-hdr".to_string(), "--eh-frame-hdr".to_string(),
@ -740,9 +797,7 @@ fn link_with_lld(
return Err(CompileError::new_err("failed to start linker")); return Err(CompileError::new_err("failed to start linker"));
} }
} else { } else {
return Err(CompileError::new_err( return Err(CompileError::new_err("linker returned non-zero status code"));
"linker returned non-zero status code",
));
} }
Ok(()) Ok(())
@ -752,7 +807,7 @@ fn add_exceptions(
composer: &mut TopLevelComposer, composer: &mut TopLevelComposer,
builtin_def: &mut HashMap<StrRef, DefinitionId>, builtin_def: &mut HashMap<StrRef, DefinitionId>,
builtin_ty: &mut HashMap<StrRef, Type>, builtin_ty: &mut HashMap<StrRef, Type>,
error_names: &[&str] error_names: &[&str],
) -> Vec<Type> { ) -> Vec<Type> {
let mut types = Vec::new(); let mut types = Vec::new();
// note: this is only for builtin exceptions, i.e. the exception name is "0:{exn}" // note: this is only for builtin exceptions, i.e. the exception name is "0:{exn}"
@ -765,7 +820,7 @@ fn add_exceptions(
// constructor id // constructor id
def_id + 1, def_id + 1,
&mut composer.unifier, &mut composer.unifier,
&composer.primitives_ty &composer.primitives_ty,
); );
composer.definition_ast_list.push((Arc::new(RwLock::new(exception_class)), None)); composer.definition_ast_list.push((Arc::new(RwLock::new(exception_class)), None));
composer.definition_ast_list.push((Arc::new(RwLock::new(exception_fn)), None)); composer.definition_ast_list.push((Arc::new(RwLock::new(exception_fn)), None));
@ -792,11 +847,11 @@ impl Nac3 {
Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS, Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS, Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS,
}; };
let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0; let primitive: PrimitiveStore = TopLevelComposer::make_primitives(isa.get_size_type()).0;
let builtins = vec![ let builtins = vec![
( (
"now_mu".into(), "now_mu".into(),
FunSignature { args: vec![], ret: primitive.int64, vars: HashMap::new() }, FunSignature { args: vec![], ret: primitive.int64, vars: VarMap::new() },
Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| { Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| {
Ok(Some(time_fns.emit_now_mu(ctx))) Ok(Some(time_fns.emit_now_mu(ctx)))
}))), }))),
@ -810,11 +865,12 @@ impl Nac3 {
default_value: None, default_value: None,
}], }],
ret: primitive.none, ret: primitive.none,
vars: HashMap::new(), vars: VarMap::new(),
}, },
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| { Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap(); let arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
time_fns.emit_at_mu(ctx, arg); time_fns.emit_at_mu(ctx, arg);
Ok(None) Ok(None)
}))), }))),
@ -828,11 +884,12 @@ impl Nac3 {
default_value: None, default_value: None,
}], }],
ret: primitive.none, ret: primitive.none,
vars: HashMap::new(), vars: VarMap::new(),
}, },
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| { Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap(); let arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
time_fns.emit_delay_mu(ctx, arg); time_fns.emit_delay_mu(ctx, arg);
Ok(None) Ok(None)
}))), }))),
@ -846,8 +903,9 @@ impl Nac3 {
let types_mod = PyModule::import(py, "types").unwrap(); let types_mod = PyModule::import(py, "types").unwrap();
let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap(); let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap();
let get_attr_id = |obj: &PyModule, attr| id_fn.call1((obj.getattr(attr).unwrap(),)) let get_attr_id = |obj: &PyModule, attr| {
.unwrap().extract().unwrap(); id_fn.call1((obj.getattr(attr).unwrap(),)).unwrap().extract().unwrap()
};
let primitive_ids = PrimitivePythonId { let primitive_ids = PrimitivePythonId {
virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()), virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()),
generic_alias: ( generic_alias: (
@ -856,16 +914,22 @@ impl Nac3 {
), ),
none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()), none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()),
typevar: get_attr_id(typing_mod, "TypeVar"), typevar: get_attr_id(typing_mod, "TypeVar"),
const_generic_marker: get_id(artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap()), const_generic_marker: get_id(
artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap(),
),
int: get_attr_id(builtins_mod, "int"), int: get_attr_id(builtins_mod, "int"),
int32: get_attr_id(numpy_mod, "int32"), int32: get_attr_id(numpy_mod, "int32"),
int64: get_attr_id(numpy_mod, "int64"), int64: get_attr_id(numpy_mod, "int64"),
uint32: get_attr_id(numpy_mod, "uint32"), uint32: get_attr_id(numpy_mod, "uint32"),
uint64: get_attr_id(numpy_mod, "uint64"), uint64: get_attr_id(numpy_mod, "uint64"),
bool: get_attr_id(builtins_mod, "bool"), bool: get_attr_id(builtins_mod, "bool"),
np_bool_: get_attr_id(numpy_mod, "bool_"),
string: get_attr_id(builtins_mod, "str"),
np_str_: get_attr_id(numpy_mod, "str_"),
float: get_attr_id(builtins_mod, "float"), float: get_attr_id(builtins_mod, "float"),
float64: get_attr_id(numpy_mod, "float64"), float64: get_attr_id(numpy_mod, "float64"),
list: get_attr_id(builtins_mod, "list"), list: get_attr_id(builtins_mod, "list"),
ndarray: get_attr_id(numpy_mod, "ndarray"),
tuple: get_attr_id(builtins_mod, "tuple"), tuple: get_attr_id(builtins_mod, "tuple"),
exception: get_attr_id(builtins_mod, "Exception"), exception: get_attr_id(builtins_mod, "Exception"),
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()), option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
@ -889,7 +953,7 @@ impl Nac3 {
llvm_options: CodeGenLLVMOptions { llvm_options: CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default, opt_level: OptimizationLevel::Default,
target: Nac3::get_llvm_target_options(isa), target: Nac3::get_llvm_target_options(isa),
} },
}) })
} }
@ -939,7 +1003,7 @@ impl Nac3 {
.expect("couldn't write module to file"); .expect("couldn't write module to file");
link_with_lld( link_with_lld(
filename.to_string(), filename.to_string(),
working_directory.join("module.o").to_string_lossy().to_string() working_directory.join("module.o").to_string_lossy().to_string(),
)?; )?;
Ok(()) Ok(())
}; };
@ -987,7 +1051,7 @@ impl Nac3 {
let filename = filename_path.to_str().unwrap(); let filename = filename_path.to_str().unwrap();
link_with_lld( link_with_lld(
filename.to_string(), filename.to_string(),
working_directory.join("module.o").to_string_lossy().to_string() working_directory.join("module.o").to_string_lossy().to_string(),
)?; )?;
Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into()) Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())

File diff suppressed because it is too large Load Diff

View File

@ -1,9 +1,12 @@
use inkwell::{values::BasicValueEnum, AddressSpace, AtomicOrdering}; use inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
};
use itertools::Either;
use nac3core::codegen::CodeGenContext; use nac3core::codegen::CodeGenContext;
/// Functions for manipulating the timeline. /// Functions for manipulating the timeline.
pub trait TimeFns { pub trait TimeFns {
/// Emits LLVM IR for `now_mu`. /// Emits LLVM IR for `now_mu`.
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>; fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>;
@ -26,32 +29,33 @@ impl TimeFns for NowPinningTimeFns64 {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = let now_hiptr = ctx
ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr"); .builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else { .map(BasicValueEnum::into_pointer_value)
unreachable!() .unwrap();
};
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}; }
.unwrap();
let (BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_lo)) = ( let now_hi = ctx
ctx.builder.build_load(now_hiptr, "now.hi"), .builder
ctx.builder.build_load(now_loptr, "now.lo"), .build_load(now_hiptr, "now.hi")
) else { .map(BasicValueEnum::into_int_value)
unreachable!() .unwrap();
}; let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, ""); let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi = ctx.builder.build_left_shift( let shifted_hi =
zext_hi, ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
i64_type.const_int(32, false), let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
"", ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into).unwrap()
);
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "");
ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").into()
} }
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -59,105 +63,100 @@ impl TimeFns for NowPinningTimeFns64 {
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
let BasicValueEnum::IntValue(time) = t else { let time = t.into_int_value();
unreachable!()
};
let time_hi = ctx.builder.build_int_truncate( let time_hi = ctx
ctx.builder.build_right_shift(time, i64_32, false, "time.hi"), .builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type, i32_type,
"", "",
); )
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo"); .unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now = ctx let now = ctx
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder.build_bitcast( let now_hiptr = ctx
now, .builder
i32_type.ptr_type(AddressSpace::default()), .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
"now.hi.addr", .map(BasicValueEnum::into_pointer_value)
); .unwrap();
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}; }
.unwrap();
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
ctx.builder ctx.builder
.build_store(now_loptr, time_lo) .build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
} }
fn emit_delay_mu<'ctx>( fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let now = ctx let now = ctx
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = let now_hiptr = ctx
ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr"); .builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else { .map(BasicValueEnum::into_pointer_value)
unreachable!() .unwrap();
};
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}; }
.unwrap();
let ( let now_hi = ctx
BasicValueEnum::IntValue(now_hi), .builder
BasicValueEnum::IntValue(now_lo), .build_load(now_hiptr, "now.hi")
BasicValueEnum::IntValue(dt), .map(BasicValueEnum::into_int_value)
) = ( .unwrap();
ctx.builder.build_load(now_hiptr, "now.hi"), let now_lo = ctx
ctx.builder.build_load(now_loptr, "now.lo"), .builder
dt, .build_load(now_loptr, "now.lo")
) else { .map(BasicValueEnum::into_int_value)
unreachable!() .unwrap();
}; let dt = dt.into_int_value();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, ""); let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi = ctx.builder.build_left_shift( let shifted_hi =
zext_hi, ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
i64_type.const_int(32, false), let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
"", let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now").unwrap();
);
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "");
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now");
let time = ctx.builder.build_int_add(now_val, dt, "time"); let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx.builder.build_int_truncate( let time_hi = ctx
ctx.builder.build_right_shift( .builder
time, .build_int_truncate(
i64_type.const_int(32, false), ctx.builder
false, .build_right_shift(time, i64_type.const_int(32, false), false, "")
"", .unwrap(),
),
i32_type, i32_type,
"time.hi", "time.hi",
); )
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo"); .unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
ctx.builder ctx.builder
.build_store(now_loptr, time_lo) .build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
} }
@ -174,16 +173,16 @@ impl TimeFns for NowPinningTimeFns {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now"); let now_raw = ctx
.builder
let BasicValueEnum::IntValue(now_raw) = now_raw else { .build_load(now.as_pointer_value(), "now")
unreachable!() .map(BasicValueEnum::into_int_value)
}; .unwrap();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo"); let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi"); let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
ctx.builder.build_or(now_lo, now_hi, "now_mu").into() ctx.builder.build_or(now_lo, now_hi, "now_mu").map(Into::into).unwrap()
} }
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -191,48 +190,44 @@ impl TimeFns for NowPinningTimeFns {
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
let BasicValueEnum::IntValue(time) = t else { let time = t.into_int_value();
unreachable!()
};
let time_hi = ctx.builder.build_int_truncate( let time_hi = ctx
ctx.builder.build_right_shift(time, i64_32, false, ""), .builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "").unwrap(),
i32_type, i32_type,
"time.hi", "time.hi",
); )
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc"); .unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc").unwrap();
let now = ctx let now = ctx
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder.build_bitcast( let now_hiptr = ctx
now, .builder
i32_type.ptr_type(AddressSpace::default()), .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
"now.hi.addr", .map(BasicValueEnum::into_pointer_value)
); .unwrap();
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}; }
.unwrap();
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
ctx.builder ctx.builder
.build_store(now_loptr, time_lo) .build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
} }
fn emit_delay_mu<'ctx>( fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
@ -240,41 +235,45 @@ impl TimeFns for NowPinningTimeFns {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder.build_load(now.as_pointer_value(), ""); let now_raw = ctx
.builder
.build_load(now.as_pointer_value(), "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = (now_raw, dt) else { let dt = dt.into_int_value();
unreachable!()
};
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo"); let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi"); let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val"); let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time"); let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx.builder.build_int_truncate( let time_hi = ctx
ctx.builder.build_right_shift(time, i64_32, false, "time.hi"), .builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type, i32_type,
"now_trunc", "now_trunc",
); )
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo"); .unwrap();
let now_hiptr = ctx.builder.build_bitcast( let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
now, let now_hiptr = ctx
i32_type.ptr_type(AddressSpace::default()), .builder
"now.hi.addr", .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
); .map(BasicValueEnum::into_pointer_value)
.unwrap();
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}; }
.unwrap();
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
ctx.builder ctx.builder
.build_store(now_loptr, time_lo) .build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap(); .unwrap();
} }
@ -289,7 +288,11 @@ impl TimeFns for ExternTimeFns {
let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| { let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| {
ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None) ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None)
}); });
ctx.builder.build_call(now_mu, &[], "now_mu").try_as_basic_value().left().unwrap() ctx.builder
.build_call(now_mu, &[], "now_mu")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap()
} }
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -300,14 +303,10 @@ impl TimeFns for ExternTimeFns {
None, None,
) )
}); });
ctx.builder.build_call(at_mu, &[t.into()], "at_mu"); ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap();
} }
fn emit_delay_mu<'ctx>( fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| { let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| {
ctx.module.add_function( ctx.module.add_function(
"delay_mu", "delay_mu",
@ -315,7 +314,7 @@ impl TimeFns for ExternTimeFns {
None, None,
) )
}); });
ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu"); ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu").unwrap();
} }
} }

View File

@ -10,7 +10,7 @@ constant-optimization = ["fold"]
fold = [] fold = []
[dependencies] [dependencies]
lazy_static = "1.4" lazy_static = "1.5"
parking_lot = "0.12" parking_lot = "0.12"
string-interner = "0.14" string-interner = "0.17"
fxhash = "0.2" fxhash = "0.2"

File diff suppressed because it is too large Load Diff

View File

@ -28,12 +28,12 @@ impl From<bool> for Constant {
} }
impl From<i32> for Constant { impl From<i32> for Constant {
fn from(i: i32) -> Constant { fn from(i: i32) -> Constant {
Self::Int(i as i128) Self::Int(i128::from(i))
} }
} }
impl From<i64> for Constant { impl From<i64> for Constant {
fn from(i: i64) -> Constant { fn from(i: i64) -> Constant {
Self::Int(i as i128) Self::Int(i128::from(i))
} }
} }
@ -50,6 +50,7 @@ pub enum ConversionFlag {
} }
impl ConversionFlag { impl ConversionFlag {
#[must_use]
pub fn try_from_byte(b: u8) -> Option<Self> { pub fn try_from_byte(b: u8) -> Option<Self> {
match b { match b {
b's' => Some(Self::Str), b's' => Some(Self::Str),
@ -69,6 +70,7 @@ pub struct ConstantOptimizer {
#[cfg(feature = "constant-optimization")] #[cfg(feature = "constant-optimization")]
impl ConstantOptimizer { impl ConstantOptimizer {
#[inline] #[inline]
#[must_use]
pub fn new() -> Self { pub fn new() -> Self {
Self { _priv: () } Self { _priv: () }
} }
@ -85,14 +87,10 @@ impl<U> crate::fold::Fold<U> for ConstantOptimizer {
fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> { fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> {
match node.node { match node.node {
crate::ExprKind::Tuple { elts, ctx } => { crate::ExprKind::Tuple { elts, ctx } => {
let elts = elts let elts =
.into_iter() elts.into_iter().map(|x| self.fold_expr(x)).collect::<Result<Vec<_>, _>>()?;
.map(|x| self.fold_expr(x)) let expr =
.collect::<Result<Vec<_>, _>>()?; if elts.iter().all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) {
let expr = if elts
.iter()
.all(|e| matches!(e.node, crate::ExprKind::Constant { .. }))
{
let tuple = elts let tuple = elts
.into_iter() .into_iter()
.map(|e| match e.node { .map(|e| match e.node {
@ -100,18 +98,11 @@ impl<U> crate::fold::Fold<U> for ConstantOptimizer {
_ => unreachable!(), _ => unreachable!(),
}) })
.collect(); .collect();
crate::ExprKind::Constant { crate::ExprKind::Constant { value: Constant::Tuple(tuple), kind: None }
value: Constant::Tuple(tuple),
kind: None,
}
} else { } else {
crate::ExprKind::Tuple { elts, ctx } crate::ExprKind::Tuple { elts, ctx }
}; };
Ok(crate::Expr { Ok(crate::Expr { node: expr, custom: node.custom, location: node.location })
node: expr,
custom: node.custom,
location: node.location,
})
} }
_ => crate::fold::fold_expr(self, node), _ => crate::fold::fold_expr(self, node),
} }
@ -127,7 +118,7 @@ mod tests {
use crate::fold::Fold; use crate::fold::Fold;
use crate::*; use crate::*;
let location = Location::new(0, 0, Default::default()); let location = Location::new(0, 0, FileName::default());
let custom = (); let custom = ();
let ast = Located { let ast = Located {
location, location,
@ -138,18 +129,12 @@ mod tests {
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 1.into(), kind: None },
value: 1.into(),
kind: None,
},
}, },
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 2.into(), kind: None },
value: 2.into(),
kind: None,
},
}, },
Located { Located {
location, location,
@ -160,26 +145,17 @@ mod tests {
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 3.into(), kind: None },
value: 3.into(),
kind: None,
},
}, },
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 4.into(), kind: None },
value: 4.into(),
kind: None,
},
}, },
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 5.into(), kind: None },
value: 5.into(),
kind: None,
},
}, },
], ],
}, },
@ -187,9 +163,7 @@ mod tests {
], ],
}, },
}; };
let new_ast = ConstantOptimizer::new() let new_ast = ConstantOptimizer::new().fold_expr(ast).unwrap_or_else(|e| match e {});
.fold_expr(ast)
.unwrap_or_else(|e| match e {});
assert_eq!( assert_eq!(
new_ast, new_ast,
Located { Located {
@ -199,11 +173,7 @@ mod tests {
value: Constant::Tuple(vec![ value: Constant::Tuple(vec![
1.into(), 1.into(),
2.into(), 2.into(),
Constant::Tuple(vec![ Constant::Tuple(vec![3.into(), 4.into(), 5.into(),])
3.into(),
4.into(),
5.into(),
])
]), ]),
kind: None kind: None
}, },

View File

@ -64,11 +64,4 @@ macro_rules! simple_fold {
}; };
} }
simple_fold!( simple_fold!(usize, String, bool, StrRef, constant::Constant, constant::ConversionFlag);
usize,
String,
bool,
StrRef,
constant::Constant,
constant::ConversionFlag
);

View File

@ -2,6 +2,7 @@ use crate::{Constant, ExprKind};
impl<U> ExprKind<U> { impl<U> ExprKind<U> {
/// Returns a short name for the node suitable for use in error messages. /// Returns a short name for the node suitable for use in error messages.
#[must_use]
pub fn name(&self) -> &'static str { pub fn name(&self) -> &'static str {
match self { match self {
ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => { ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => {
@ -34,10 +35,7 @@ impl<U> ExprKind<U> {
ExprKind::Starred { .. } => "starred", ExprKind::Starred { .. } => "starred",
ExprKind::Slice { .. } => "slice", ExprKind::Slice { .. } => "slice",
ExprKind::JoinedStr { values } => { ExprKind::JoinedStr { values } => {
if values if values.iter().any(|e| matches!(e.node, ExprKind::JoinedStr { .. })) {
.iter()
.any(|e| matches!(e.node, ExprKind::JoinedStr { .. }))
{
"f-string expression" "f-string expression"
} else { } else {
"literal" "literal"

View File

@ -1,3 +1,19 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::too_many_lines,
clippy::wildcard_imports
)]
#[macro_use] #[macro_use]
extern crate lazy_static; extern crate lazy_static;
@ -9,6 +25,6 @@ mod impls;
mod location; mod location;
pub use ast_gen::*; pub use ast_gen::*;
pub use location::{Location, FileName}; pub use location::{FileName, Location};
pub type Suite<U = ()> = Vec<Stmt<U>>; pub type Suite<U = ()> = Vec<Stmt<U>>;

View File

@ -1,6 +1,6 @@
//! Datatypes to support source location information. //! Datatypes to support source location information.
use std::cmp::Ordering;
use crate::ast_gen::StrRef; use crate::ast_gen::StrRef;
use std::cmp::Ordering;
use std::fmt; use std::fmt;
#[derive(Clone, Copy, Debug, Eq, PartialEq)] #[derive(Clone, Copy, Debug, Eq, PartialEq)]
@ -22,7 +22,7 @@ impl From<String> for FileName {
pub struct Location { pub struct Location {
pub row: usize, pub row: usize,
pub column: usize, pub column: usize,
pub file: FileName pub file: FileName,
} }
impl fmt::Display for Location { impl fmt::Display for Location {
@ -35,12 +35,12 @@ impl Ord for Location {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string()); let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string());
if file_cmp != Ordering::Equal { if file_cmp != Ordering::Equal {
return file_cmp return file_cmp;
} }
let row_cmp = self.row.cmp(&other.row); let row_cmp = self.row.cmp(&other.row);
if row_cmp != Ordering::Equal { if row_cmp != Ordering::Equal {
return row_cmp return row_cmp;
} }
self.column.cmp(&other.column) self.column.cmp(&other.column)
@ -76,23 +76,22 @@ impl Location {
) )
} }
} }
Visualize { Visualize { loc: *self, line, desc }
loc: *self,
line,
desc,
}
} }
} }
impl Location { impl Location {
#[must_use]
pub fn new(row: usize, column: usize, file: FileName) -> Self { pub fn new(row: usize, column: usize, file: FileName) -> Self {
Location { row, column, file } Location { row, column, file }
} }
#[must_use]
pub fn row(&self) -> usize { pub fn row(&self) -> usize {
self.row self.row
} }
#[must_use]
pub fn column(&self) -> usize { pub fn column(&self) -> usize {
self.column self.column
} }

View File

@ -1,3 +1,6 @@
[features]
test = []
[package] [package]
name = "nac3core" name = "nac3core"
version = "0.1.0" version = "0.1.0"
@ -5,14 +8,17 @@ authors = ["M-Labs"]
edition = "2021" edition = "2021"
[dependencies] [dependencies]
itertools = "0.12" itertools = "0.13"
crossbeam = "0.8" crossbeam = "0.8"
indexmap = "2.2"
parking_lot = "0.12" parking_lot = "0.12"
rayon = "1.5" rayon = "1.8"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
strum = "0.26.2"
strum_macros = "0.26.4"
[dependencies.inkwell] [dependencies.inkwell]
version = "0.2" version = "0.4"
default-features = false default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"] features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]

View File

@ -7,31 +7,41 @@ use std::{
process::{Command, Stdio}, process::{Command, Stdio},
}; };
fn main() { fn compile_irrt(irrt_dir: &Path, out_dir: &Path) {
const FILE: &str = "src/codegen/irrt/irrt.c"; let irrt_cpp_path = irrt_dir.join("irrt.cpp");
/* /*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode. * HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get. * Compiling for WASM32 and filtering the output with regex is the closest we can get.
*/ */
const FLAG: &[&str] = &[ let flags: &[&str] = &[
"--target=wasm32", "--target=wasm32",
FILE, irrt_cpp_path.to_str().unwrap(),
"-O3", "-x",
"c++",
"-fno-discard-value-names",
"-fno-exceptions",
"-fno-rtti",
match env::var("PROFILE").as_deref() {
Ok("debug") => "-O0",
Ok("release") => "-O3",
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
},
"-emit-llvm", "-emit-llvm",
"-S", "-S",
"-Wall", "-Wall",
"-Wextra", "-Wextra",
"-Werror=return-type",
"-I",
irrt_dir.to_str().unwrap(),
"-o", "-o",
"-", "-",
]; ];
println!("cargo:rerun-if-changed={FILE}"); println!("cargo:rerun-if-changed={}", out_dir.to_str().unwrap());
let out_dir = env::var("OUT_DIR").unwrap();
let out_path = Path::new(&out_dir);
let output = Command::new("clang-irrt") let output = Command::new("clang-irrt")
.args(FLAG) .args(flags)
.output() .output()
.map(|o| { .map(|o| {
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap()); assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
@ -43,7 +53,11 @@ fn main() {
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n"); let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
let mut filtered_output = String::with_capacity(output.len()); let mut filtered_output = String::with_capacity(output.len());
let regex_filter = Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap(); // (?ms:^define.*?\}$) to capture `define` blocks
// (?m:^declare.*?$) to capture `declare` blocks
// (?m:^%.+?=\s*type\s*\{.+?\}$) to capture `type` declarations
let regex_filter =
Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)").unwrap();
for f in regex_filter.captures_iter(&output) { for f in regex_filter.captures_iter(&output) {
assert_eq!(f.len(), 1); assert_eq!(f.len(), 1);
filtered_output.push_str(&f[0]); filtered_output.push_str(&f[0]);
@ -56,18 +70,65 @@ fn main() {
println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT"); println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT");
if env::var("DEBUG_DUMP_IRRT").is_ok() { if env::var("DEBUG_DUMP_IRRT").is_ok() {
let mut file = File::create(out_path.join("irrt.ll")).unwrap(); let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
file.write_all(output.as_bytes()).unwrap(); file.write_all(output.as_bytes()).unwrap();
let mut file = File::create(out_path.join("irrt-filtered.ll")).unwrap(); let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
file.write_all(filtered_output.as_bytes()).unwrap(); file.write_all(filtered_output.as_bytes()).unwrap();
} }
let mut llvm_as = Command::new("llvm-as-irrt") let mut llvm_as = Command::new("llvm-as-irrt")
.stdin(Stdio::piped()) .stdin(Stdio::piped())
.arg("-o") .arg("-o")
.arg(out_path.join("irrt.bc")) .arg(out_dir.join("irrt.bc"))
.spawn() .spawn()
.unwrap(); .unwrap();
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap(); llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
assert!(llvm_as.wait().unwrap().success()); assert!(llvm_as.wait().unwrap().success());
} }
fn compile_irrt_test(irrt_dir: &Path, out_dir: &Path) {
let irrt_test_cpp_path = irrt_dir.join("irrt_test.cpp");
let exe_path = out_dir.join("irrt_test.out");
let flags: &[&str] = &[
irrt_test_cpp_path.to_str().unwrap(),
"-x",
"c++",
"-I",
irrt_dir.to_str().unwrap(),
"-g",
"-fno-discard-value-names",
"-O0",
"-Wall",
"-Wextra",
"-Werror=return-type",
"-lm", // for `tgamma()`, `lgamma()`
"-o",
exe_path.to_str().unwrap(),
];
Command::new("clang-irrt-test")
.args(flags)
.output()
.map(|o| {
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
o
})
.unwrap();
println!("cargo:rerun-if-changed={}", out_dir.to_str().unwrap());
}
fn main() {
let out_dir = env::var("OUT_DIR").unwrap();
let out_dir = Path::new(&out_dir);
let irrt_dir = Path::new("./irrt");
compile_irrt(irrt_dir, out_dir);
// https://github.com/rust-lang/cargo/issues/2549
// `cargo test -F test` to also build `irrt_test.cpp
if cfg!(feature = "test") {
compile_irrt_test(irrt_dir, out_dir);
}
}

5
nac3core/irrt/irrt.cpp Normal file
View File

@ -0,0 +1,5 @@
#include "irrt_everything.hpp"
/*
This file will be read by `clang-irrt` to conveniently produce LLVM IR for `nac3core/codegen`.
*/

437
nac3core/irrt/irrt.hpp Normal file
View File

@ -0,0 +1,437 @@
#ifndef IRRT_DONT_TYPEDEF_INTS
typedef _BitInt(8) int8_t;
typedef unsigned _BitInt(8) uint8_t;
typedef _BitInt(32) int32_t;
typedef unsigned _BitInt(32) uint32_t;
typedef _BitInt(64) int64_t;
typedef unsigned _BitInt(64) uint64_t;
#endif
// NDArray indices are always `uint32_t`.
typedef uint32_t NDIndex;
// The type of an index or a value describing the length of a range/slice is
// always `int32_t`.
typedef int32_t SliceIndex;
template <typename T>
static T max(T a, T b) {
return a > b ? a : b;
}
template <typename T>
static T min(T a, T b) {
return a > b ? b : a;
}
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template <typename T>
static T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
template <typename SizeT>
static SizeT __nac3_ndarray_calc_size_impl(
const SizeT *list_data,
SizeT list_len,
SizeT begin_idx,
SizeT end_idx
) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template <typename SizeT>
static void __nac3_ndarray_calc_nd_indices_impl(
SizeT index,
const SizeT *dims,
SizeT num_dims,
NDIndex *idxs
) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template <typename SizeT>
static SizeT __nac3_ndarray_flatten_index_impl(
const SizeT *dims,
SizeT num_dims,
const NDIndex *indices,
SizeT num_indices
) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template <typename SizeT>
static void __nac3_ndarray_calc_broadcast_impl(
const SizeT *lhs_dims,
SizeT lhs_ndims,
const SizeT *rhs_dims,
SizeT rhs_ndims,
SizeT *out_dims
) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT *out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template <typename SizeT>
static void __nac3_ndarray_calc_broadcast_idx_impl(
const SizeT *src_dims,
SizeT src_ndims,
const NDIndex *in_idx,
NDIndex *out_idx
) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
template<typename SizeT>
static void __nac3_ndarray_strides_from_shape_impl(
SizeT ndims,
SizeT *shape,
SizeT *dst_strides
) {
SizeT stride_product = 1;
for (SizeT i = 0; i < ndims; i++) {
int dim_i = ndims - i - 1;
dst_strides[dim_i] = stride_product;
stride_product *= shape[dim_i];
}
}
extern "C" {
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) {\
return __nac3_int_exp_impl(base, exp);\
}
DEF_nac3_int_exp_(int32_t)
DEF_nac3_int_exp_(int64_t)
DEF_nac3_int_exp_(uint32_t)
DEF_nac3_int_exp_(uint64_t)
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
SliceIndex __nac3_range_slice_len(
const SliceIndex start,
const SliceIndex end,
const SliceIndex step
) {
SliceIndex diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(
SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
uint8_t *dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
uint8_t *src_arr,
SliceIndex src_arr_len,
const SliceIndex size
) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(
dest_arr + dest_start * size,
src_arr + src_start * size,
src_len * size
);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(
dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca =
(dest_arr == src_arr)
&& !(
max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end)
);
if (need_alloca) {
uint8_t *tmp = reinterpret_cast<uint8_t *>(__builtin_alloca(src_arr_len * size));
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (;
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
src_ind += src_step, dest_ind += dest_step
) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(
dest_arr + dest_ind * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
uint32_t __nac3_ndarray_calc_size(
const uint32_t *list_data,
uint32_t list_len,
uint32_t begin_idx,
uint32_t end_idx
) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t __nac3_ndarray_calc_size64(
const uint64_t *list_data,
uint64_t list_len,
uint64_t begin_idx,
uint64_t end_idx
) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(
uint32_t index,
const uint32_t* dims,
uint32_t num_dims,
NDIndex* idxs
) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(
uint64_t index,
const uint64_t* dims,
uint64_t num_dims,
NDIndex* idxs
) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t __nac3_ndarray_flatten_index(
const uint32_t* dims,
uint32_t num_dims,
const NDIndex* indices,
uint32_t num_indices
) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t __nac3_ndarray_flatten_index64(
const uint64_t* dims,
uint64_t num_dims,
const NDIndex* indices,
uint64_t num_indices
) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(
const uint32_t *lhs_dims,
uint32_t lhs_ndims,
const uint32_t *rhs_dims,
uint32_t rhs_ndims,
uint32_t *out_dims
) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(
const uint64_t *lhs_dims,
uint64_t lhs_ndims,
const uint64_t *rhs_dims,
uint64_t rhs_ndims,
uint64_t *out_dims
) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(
const uint32_t *src_dims,
uint32_t src_ndims,
const NDIndex *in_idx,
NDIndex *out_idx
) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(
const uint64_t *src_dims,
uint64_t src_ndims,
const NDIndex *in_idx,
NDIndex *out_idx
) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_strides_from_shape(uint32_t ndims, uint32_t* shape, uint32_t* dst_strides) {
__nac3_ndarray_strides_from_shape_impl(ndims, shape, dst_strides);
}
void __nac3_ndarray_strides_from_shape64(uint64_t ndims, uint64_t* shape, uint64_t* dst_strides) {
__nac3_ndarray_strides_from_shape_impl(ndims, shape, dst_strides);
}
}

View File

@ -0,0 +1,216 @@
#pragma once
#include "irrt_utils.hpp"
#include "irrt_typedefs.hpp"
/*
This header contains IRRT implementations
that do not deserved to be categorized (e.g., into numpy, etc.)
Check out other *.hpp files before including them here!!
*/
// The type of an index or a value describing the length of a range/slice is
// always `int32_t`.
namespace {
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template <typename T>
T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
}
extern "C" {
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) {\
return __nac3_int_exp_impl(base, exp);\
}
DEF_nac3_int_exp_(int32_t)
DEF_nac3_int_exp_(int64_t)
DEF_nac3_int_exp_(uint32_t)
DEF_nac3_int_exp_(uint64_t)
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
SliceIndex __nac3_range_slice_len(
const SliceIndex start,
const SliceIndex end,
const SliceIndex step
) {
SliceIndex diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(
SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
uint8_t *dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
uint8_t *src_arr,
SliceIndex src_arr_len,
const SliceIndex size
) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(
dest_arr + dest_start * size,
src_arr + src_start * size,
src_len * size
);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(
dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca =
(dest_arr == src_arr)
&& !(
max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end)
);
if (need_alloca) {
uint8_t *tmp = reinterpret_cast<uint8_t *>(__builtin_alloca(src_arr_len * size));
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (;
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
src_ind += src_step, dest_ind += dest_step
) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(
dest_arr + dest_ind * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
}

View File

@ -0,0 +1,14 @@
#pragma once
#include "irrt_utils.hpp"
#include "irrt_typedefs.hpp"
#include "irrt_basic.hpp"
#include "irrt_slice.hpp"
#include "irrt_numpy_ndarray.hpp"
/*
All IRRT implementations.
We don't have any pre-compiled objects, so we are writing all implementations in headers and
concatenate them with `#include` into one massive source file that contains all the IRRT stuff.
*/

View File

@ -0,0 +1,466 @@
#pragma once
#include "irrt_utils.hpp"
#include "irrt_typedefs.hpp"
#include "irrt_slice.hpp"
/*
NDArray-related implementations.
`*/
// NDArray indices are always `uint32_t`.
using NDIndex = uint32_t;
namespace {
namespace ndarray_util {
template <typename SizeT>
static void set_indices_by_nth(SizeT ndims, const SizeT* shape, SizeT* indices, SizeT nth) {
for (int32_t i = 0; i < ndims; i++) {
int32_t dim_i = ndims - i - 1;
int32_t dim = shape[dim_i];
indices[dim_i] = nth % dim;
nth /= dim;
}
}
// Compute the strides of an ndarray given an ndarray `shape`
// and assuming that the ndarray is *fully C-contagious*.
//
// You might want to read up on https://ajcr.net/stride-guide-part-1/.
template <typename SizeT>
static void set_strides_by_shape(SizeT itemsize, SizeT ndims, SizeT* dst_strides, const SizeT* shape) {
SizeT stride_product = 1;
for (SizeT i = 0; i < ndims; i++) {
int dim_i = ndims - i - 1;
dst_strides[dim_i] = stride_product * itemsize;
stride_product *= shape[dim_i];
}
}
// Compute the size/# of elements of an ndarray given its shape
template <typename SizeT>
static SizeT calc_size_from_shape(SizeT ndims, const SizeT* shape) {
SizeT size = 1;
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) size *= shape[dim_i];
return size;
}
template <typename SizeT>
static bool can_broadcast_shape_to(
const SizeT target_ndims,
const SizeT *target_shape,
const SizeT src_ndims,
const SizeT *src_shape
) {
/*
// See https://numpy.org/doc/stable/user/basics.broadcasting.html
This function handles this example:
```
Image (3d array): 256 x 256 x 3
Scale (1d array): 3
Result (3d array): 256 x 256 x 3
```
Other interesting examples to consider:
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
- `can_broadcast_shape_to([3], [3, 1]) == false`
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
In cases when the shapes contain zero(es):
- `can_broadcast_shape_to([0], [1]) == true`
- `can_broadcast_shape_to([0], [2]) == false`
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
- `can_broadcast_shape_to([4, 3], [0, 3]) == false`
- `can_broadcast_shape_to([4, 3], [0, 0]) == false`
*/
// This is essentially doing the following in Python:
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
SizeT target_dim_i = target_ndims - i - 1;
SizeT src_dim_i = src_ndims - i - 1;
bool target_dim_exists = target_dim_i >= 0;
bool src_dim_exists = src_dim_i >= 0;
SizeT target_dim = target_dim_exists ? target_shape[target_dim_i] : 1;
SizeT src_dim = src_dim_exists ? src_shape[src_dim_i] : 1;
bool ok = src_dim == 1 || target_dim == src_dim;
if (!ok) return false;
}
return true;
}
}
typedef uint8_t NDSliceType;
extern "C" {
const NDSliceType INPUT_SLICE_TYPE_INDEX = 0;
const NDSliceType INPUT_SLICE_TYPE_SLICE = 1;
}
struct NDSlice {
// A poor-man's `std::variant<int, UserRange>`
NDSliceType type;
/*
if type == INPUT_SLICE_TYPE_INDEX => `slice` points to a single `SizeT`
if type == INPUT_SLICE_TYPE_SLICE => `slice` points to a single `UserRange`
*/
uint8_t *slice;
};
namespace ndarray_util {
template<typename SizeT>
SizeT deduce_ndims_after_slicing(SizeT ndims, SizeT num_slices, const NDSlice *slices) {
irrt_assert(num_slices <= ndims);
SizeT final_ndims = ndims;
for (SizeT i = 0; i < num_slices; i++) {
if (slices[i].type == INPUT_SLICE_TYPE_INDEX) {
final_ndims--; // An integer slice demotes the rank by 1
}
}
return final_ndims;
}
}
template <typename SizeT>
struct NDArrayIndicesIter {
SizeT ndims;
const SizeT *shape;
SizeT *indices;
void set_indices_zero() {
__builtin_memset(indices, 0, sizeof(SizeT) * ndims);
}
void next() {
for (SizeT i = 0; i < ndims; i++) {
SizeT dim_i = ndims - i - 1;
indices[dim_i]++;
if (indices[dim_i] < shape[dim_i]) {
break;
} else {
indices[dim_i] = 0;
}
}
}
};
// The NDArray object. `SizeT` is the *signed* size type of this ndarray.
//
// NOTE: The order of fields is IMPORTANT. DON'T TOUCH IT
//
// Some resources you might find helpful:
// - The official numpy implementations:
// - https://github.com/numpy/numpy/blob/735a477f0bc2b5b84d0e72d92f224bde78d4e069/doc/source/reference/c-api/types-and-structures.rst
// - On strides (about reshaping, slicing, C-contagiousness, etc)
// - https://ajcr.net/stride-guide-part-1/.
// - https://ajcr.net/stride-guide-part-2/.
// - https://ajcr.net/stride-guide-part-3/.
template <typename SizeT>
struct NDArray {
// The underlying data this `ndarray` is pointing to.
//
// NOTE: Formally this should be of type `void *`, but clang
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
// so we will put `uint8_t *` here for clarity.
uint8_t *data;
// The number of bytes of a single element in `data`.
//
// The `SizeT` is treated as `unsigned`.
SizeT itemsize;
// The number of dimensions of this shape.
//
// The `SizeT` is treated as `unsigned`.
SizeT ndims;
// Array shape, with length equal to `ndims`.
//
// The `SizeT` is treated as `unsigned`.
//
// NOTE: `shape` can contain 0.
// (those appear when the user makes an out of bounds slice into an ndarray, e.g., `np.zeros((3, 3))[400:].shape == (0, 3)`)
SizeT *shape;
// Array strides (stride value is in number of bytes, NOT number of elements), with length equal to `ndims`.
//
// The `SizeT` is treated as `signed`.
//
// NOTE: `strides` can have negative numbers.
// (those appear when there is a slice with a negative step, e.g., `my_array[::-1]`)
SizeT *strides;
// Calculate the size/# of elements of an `ndarray`.
// This function corresponds to `np.size(<ndarray>)` or `ndarray.size`
SizeT size() {
return ndarray_util::calc_size_from_shape(ndims, shape);
}
// Calculate the number of bytes of its content of an `ndarray` *in its view*.
// This function corresponds to `ndarray.nbytes`
SizeT nbytes() {
return this->size() * itemsize;
}
void set_value_at_pelement(uint8_t* pelement, const uint8_t* pvalue) {
__builtin_memcpy(pelement, pvalue, itemsize);
}
uint8_t* get_pelement(const SizeT *indices) {
uint8_t* element = data;
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
element += indices[dim_i] * strides[dim_i];
return element;
}
uint8_t* get_nth_pelement(SizeT nth) {
irrt_assert(0 <= nth);
irrt_assert(nth < this->size());
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims);
ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth);
return get_pelement(indices);
}
// Get pointer to the first element of this ndarray, assuming
// `this->size() > 0`, i.e., not "degenerate" due to zeroes in `this->shape`)
//
// This is particularly useful for when the ndarray is just containing a single scalar.
uint8_t* get_first_pelement() {
irrt_assert(this->size() > 0);
return this->data; // ...It is simply `this->data`
}
// Is the given `indices` valid/in-bounds?
bool in_bounds(const SizeT *indices) {
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) {
bool dim_ok = indices[dim_i] < shape[dim_i];
if (!dim_ok) return false;
}
return true;
}
// Fill the ndarray with a value
void fill_generic(const uint8_t* pvalue) {
NDArrayIndicesIter<SizeT> iter;
iter.ndims = this->ndims;
iter.shape = this->shape;
iter.indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndims);
iter.set_indices_zero();
for (SizeT i = 0; i < this->size(); i++, iter.next()) {
uint8_t* pelement = get_pelement(iter.indices);
set_value_at_pelement(pelement, pvalue);
}
}
// Set the strides of the ndarray with `ndarray_util::set_strides_by_shape`
void set_strides_by_shape() {
ndarray_util::set_strides_by_shape(itemsize, ndims, strides, shape);
}
// https://numpy.org/doc/stable/reference/generated/numpy.eye.html
void set_to_eye(SizeT k, const uint8_t* zero_pvalue, const uint8_t* one_pvalue) {
__builtin_assume(ndims == 2);
// TODO: Better implementation
fill_generic(zero_pvalue);
for (SizeT i = 0; i < min(shape[0], shape[1]); i++) {
SizeT row = i;
SizeT col = i + k;
SizeT indices[2] = { row, col };
if (!in_bounds(indices)) continue;
uint8_t* pelement = get_pelement(indices);
set_value_at_pelement(pelement, one_pvalue);
}
}
// To support numpy complex slices (e.g., `my_array[:50:2,4,:2:-1]`)
//
// Things assumed by this function:
// - `dst_ndarray` is allocated by the caller
// - `dst_ndarray.ndims` has the correct value (according to `ndarray_util::deduce_ndims_after_slicing`).
// - ... and `dst_ndarray.shape` and `dst_ndarray.strides` have been allocated by the caller as well
//
// Other notes:
// - `dst_ndarray->data` does not have to be set, it will be derived.
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->itemsize`
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
void slice(SizeT num_ndslices, NDSlice* ndslices, NDArray<SizeT>* dst_ndarray) {
// REFERENCE CODE (check out `_index_helper` in `__getitem__`):
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
irrt_assert(dst_ndarray->ndims == ndarray_util::deduce_ndims_after_slicing(this->ndims, num_ndslices, ndslices));
dst_ndarray->data = this->data;
SizeT this_axis = 0;
SizeT dst_axis = 0;
for (SizeT i = 0; i < num_ndslices; i++) {
NDSlice *ndslice = &ndslices[i];
if (ndslice->type == INPUT_SLICE_TYPE_INDEX) {
// Handle when the ndslice is just a single (possibly negative) integer
// e.g., `my_array[::2, -5, ::-1]`
// ^^------ like this
SizeT index_user = *((SizeT*) ndslice->slice);
SizeT index = resolve_index_in_length(this->shape[this_axis], index_user);
dst_ndarray->data += index * this->strides[this_axis]; // Add offset
// Next
this_axis++;
} else if (ndslice->type == INPUT_SLICE_TYPE_SLICE) {
// Handle when the ndslice is a slice (represented by UserSlice in IRRT)
// e.g., `my_array[::2, -5, ::-1]`
// ^^^------^^^^----- like these
UserSlice<SizeT>* user_slice = (UserSlice<SizeT>*) ndslice->slice;
Slice<SizeT> slice = user_slice->indices(this->shape[this_axis]); // To resolve negative indices and other funny stuff written by the user
// NOTE: There is no need to write special code to handle negative steps/strides.
// This simple implementation meticulously handles both positive and negative steps/strides.
// Check out the tinynumpy and IRRT's test cases if you are not convinced.
dst_ndarray->data += slice.start * this->strides[this_axis]; // Add offset (NOTE: no need to `* itemsize`, strides count in # of bytes)
dst_ndarray->strides[dst_axis] = slice.step * this->strides[this_axis]; // Determine stride
dst_ndarray->shape[dst_axis] = slice.len(); // Determine shape dimension
// Next
dst_axis++;
this_axis++;
} else {
__builtin_unreachable();
}
}
irrt_assert(dst_axis == dst_ndarray->ndims); // Sanity check on the implementation
}
// Similar to `np.broadcast_to(<ndarray>, <target_shape>)`
// Assumptions:
// - `this` has to be fully initialized.
// - `dst_ndarray->ndims` has to be set.
// - `dst_ndarray->shape` has to be set, this determines the shape `this` broadcasts to.
//
// Other notes:
// - `dst_ndarray->data` does not have to be set, it will be set to `this->data`.
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `this->data`.
// - `dst_ndarray->strides` does not have to be set, it will be overwritten.
//
// Cautions:
// ```
// xs = np.zeros((4,))
// ys = np.zero((4, 1))
// ys[:] = xs # ok
//
// xs = np.zeros((1, 4))
// ys = np.zero((4,))
// ys[:] = xs # allowed
// # However `np.broadcast_to(xs, (4,))` would fails, as per numpy's broadcasting rule.
// # and apparently numpy will "deprecate" this? SEE https://github.com/numpy/numpy/issues/21744
// # This implementation will NOT support this assignment.
// ```
void broadcast_to(NDArray<SizeT>* dst_ndarray) {
dst_ndarray->data = this->data;
dst_ndarray->itemsize = this->itemsize;
irrt_assert(
ndarray_util::can_broadcast_shape_to(
dst_ndarray->ndims,
dst_ndarray->shape,
this->ndims,
this->shape
)
);
SizeT stride_product = 1;
for (SizeT i = 0; i < max(this->ndims, dst_ndarray->ndims); i++) {
SizeT this_dim_i = this->ndims - i - 1;
SizeT dst_dim_i = dst_ndarray->ndims - i - 1;
bool this_dim_exists = this_dim_i >= 0;
bool dst_dim_exists = dst_dim_i >= 0;
// TODO: Explain how this works
bool c1 = this_dim_exists && this->shape[this_dim_i] == 1;
bool c2 = dst_dim_exists && dst_ndarray->shape[dst_dim_i] != 1;
if (!this_dim_exists || (c1 && c2)) {
dst_ndarray->strides[dst_dim_i] = 0; // Freeze it in-place
} else {
dst_ndarray->strides[dst_dim_i] = stride_product * this->itemsize;
stride_product *= this->shape[this_dim_i]; // NOTE: this_dim_exist must be true here.
}
}
}
// Simulates `this_ndarray[:] = src_ndarray`, with automatic broadcasting.
// Caution on https://github.com/numpy/numpy/issues/21744
// Also see `NDArray::broadcast_to`
void assign_with(NDArray<SizeT>* src_ndarray) {
irrt_assert(
ndarray_util::can_broadcast_shape_to(
this->ndims,
this->shape,
src_ndarray->ndims,
src_ndarray->shape
)
);
// Broadcast the `src_ndarray` to make the reading process *much* easier
SizeT* broadcasted_src_ndarray_strides = __builtin_alloca(sizeof(SizeT) * this->ndims); // Remember to allocate strides beforehand
NDArray<SizeT> broadcasted_src_ndarray = {
.ndims = this->ndims,
.shape = this->shape,
.strides = broadcasted_src_ndarray_strides
};
src_ndarray->broadcast_to(&broadcasted_src_ndarray);
// Using iter instead of `get_nth_pelement` because it is slightly faster
SizeT* indices = __builtin_alloca(sizeof(SizeT) * this->ndims);
auto iter = NDArrayIndicesIter<SizeT> {
.ndims = this->ndims,
.shape = this->shape,
.indices = indices
};
const SizeT this_size = this->size();
for (SizeT i = 0; i < this_size; i++, iter.next()) {
uint8_t* src_pelement = broadcasted_src_ndarray_strides->get_pelement(indices);
uint8_t* this_pelement = this->get_pelement(indices);
this->set_value_at_pelement(src_pelement, src_pelement);
}
}
};
}
extern "C" {
uint32_t __nac3_ndarray_size(NDArray<int32_t>* ndarray) {
return ndarray->size();
}
uint64_t __nac3_ndarray_size64(NDArray<int64_t>* ndarray) {
return ndarray->size();
}
void __nac3_ndarray_fill_generic(NDArray<int32_t>* ndarray, uint8_t* pvalue) {
ndarray->fill_generic(pvalue);
}
void __nac3_ndarray_fill_generic64(NDArray<int64_t>* ndarray, uint8_t* pvalue) {
ndarray->fill_generic(pvalue);
}
// void __nac3_ndarray_slice(NDArray<int32_t>* ndarray, int32_t num_slices, NDSlice<int32_t> *slices, NDArray<int32_t> *dst_ndarray) {
// // ndarray->slice(num_slices, slices, dst_ndarray);
// }
}

View File

@ -0,0 +1,80 @@
#pragma once
#include "irrt_utils.hpp"
#include "irrt_typedefs.hpp"
namespace {
// A proper slice in IRRT, all negative indices have be resolved to absolute values.
// Even though nac3core's slices are always `int32_t`, we will template slice anyway
// since this struct is used as a general utility.
template <typename T>
struct Slice {
T start;
T stop;
T step;
// The length/The number of elements of the slice if it were a range,
// i.e., the value of `len(range(this->start, this->stop, this->end))`
T len() {
T diff = stop - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
};
template<typename T>
T resolve_index_in_length(T length, T index) {
irrt_assert(length >= 0);
if (index < 0) {
// Remember that index is negative, so do a plus here
return max(length + index, 0);
} else {
return min(length, index);
}
}
// NOTE: using a bitfield for the `*_defined` is better, at the
// cost of a more annoying implementation in nac3core inkwell
template <typename T>
struct UserSlice {
uint8_t start_defined;
T start;
uint8_t stop_defined;
T stop;
uint8_t step_defined;
T step;
// Like Python's `slice(start, stop, step).indices(length)`
Slice<T> indices(T length) {
// NOTE: This function implements Python's `slice.indices` *FAITHFULLY*.
// SEE: https://github.com/python/cpython/blob/f62161837e68c1c77961435f1b954412dd5c2b65/Objects/sliceobject.c#L546
irrt_assert(length >= 0);
irrt_assert(!step_defined || step != 0); // step_defined -> step != 0; step cannot be zero if specified by user
Slice<T> result;
result.step = step_defined ? step : 1;
bool step_is_negative = result.step < 0;
if (start_defined) {
result.start = resolve_index_in_length(length, start);
} else {
result.start = step_is_negative ? length - 1 : 0;
}
if (stop_defined) {
result.stop = resolve_index_in_length(length, stop);
} else {
result.stop = step_is_negative ? -1 : length;
}
return result;
}
};
}

658
nac3core/irrt/irrt_test.cpp Normal file
View File

@ -0,0 +1,658 @@
// This file will be compiled like a real C++ program,
// and we do have the luxury to use the standard libraries.
// That is if the nix flakes do not have issues... especially on msys2...
#include <cstdint>
#include <cstdio>
#include <cstdlib>
// Set `IRRT_DONT_TYPEDEF_INTS` because `cstdint` defines them
#define IRRT_DONT_TYPEDEF_INTS
#include "irrt_everything.hpp"
void test_fail() {
printf("[!] Test failed\n");
exit(1);
}
void __begin_test(const char* function_name, const char* file, int line) {
printf("######### Running %s @ %s:%d\n", function_name, file, line);
}
#define BEGIN_TEST() __begin_test(__FUNCTION__, __FILE__, __LINE__)
template <typename T>
void debug_print_array(const char* format, int len, T* as) {
printf("[");
for (int i = 0; i < len; i++) {
if (i != 0) printf(", ");
printf(format, as[i]);
}
printf("]");
}
template <typename T>
void assert_arrays_match(const char* label, const char* format, int len, T* expected, T* got) {
if (!arrays_match(len, expected, got)) {
printf(">>>>>>> %s\n", label);
printf(" Expecting = ");
debug_print_array(format, len, expected);
printf("\n");
printf(" Got = ");
debug_print_array(format, len, got);
printf("\n");
test_fail();
}
}
template <typename T>
void assert_values_match(const char* label, const char* format, T expected, T got) {
if (expected != got) {
printf(">>>>>>> %s\n", label);
printf(" Expecting = ");
printf(format, expected);
printf("\n");
printf(" Got = ");
printf(format, got);
printf("\n");
test_fail();
}
}
void print_repeated(const char *str, int count) {
for (int i = 0; i < count; i++) {
printf("%s", str);
}
}
template<typename SizeT, typename ElementT>
void __print_ndarray_aux(const char *format, bool first, bool last, SizeT* cursor, SizeT depth, NDArray<SizeT>* ndarray) {
// A really lazy recursive implementation
// Add left padding unless its the first entry (since there would be "[[[" before it)
if (!first) {
print_repeated(" ", depth);
}
const SizeT dim = ndarray->shape[depth];
if (depth + 1 == ndarray->ndims) {
// Recursed down to last dimension, print the values in a nice list
printf("[");
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims);
for (SizeT i = 0; i < dim; i++) {
ndarray_util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor);
ElementT* pelement = (ElementT*) ndarray->get_pelement(indices);
ElementT element = *pelement;
if (i != 0) printf(", "); // List delimiter
printf(format, element);
printf("(@");
debug_print_array("%d", ndarray->ndims, indices);
printf(")");
(*cursor)++;
}
printf("]");
} else {
printf("[");
for (SizeT i = 0; i < ndarray->shape[depth]; i++) {
__print_ndarray_aux<SizeT, ElementT>(
format,
i == 0, // first?
i + 1 == dim, // last?
cursor,
depth + 1,
ndarray
);
}
printf("]");
}
// Add newline unless its the last entry (since there will be "]]]" after it)
if (!last) {
print_repeated("\n", depth);
}
}
template<typename SizeT, typename ElementT>
void print_ndarray(const char *format, NDArray<SizeT>* ndarray) {
if (ndarray->ndims == 0) {
printf("<empty ndarray>");
} else {
SizeT cursor = 0;
__print_ndarray_aux<SizeT, ElementT>(format, true, true, &cursor, 0, ndarray);
}
printf("\n");
}
void test_calc_size_from_shape_normal() {
// Test shapes with normal values
BEGIN_TEST();
int32_t shape[4] = { 2, 3, 5, 7 };
assert_values_match("size", "%d", 210, ndarray_util::calc_size_from_shape<int32_t>(4, shape));
}
void test_calc_size_from_shape_has_zero() {
// Test shapes with 0 in them
BEGIN_TEST();
int32_t shape[4] = { 2, 0, 5, 7 };
assert_values_match("size", "%d", 0, ndarray_util::calc_size_from_shape<int32_t>(4, shape));
}
void test_set_strides_by_shape() {
// Test `set_strides_by_shape()`
BEGIN_TEST();
int32_t shape[4] = { 99, 3, 5, 7 };
int32_t strides[4] = { 0 };
ndarray_util::set_strides_by_shape((int32_t) sizeof(int32_t), 4, strides, shape);
int32_t expected_strides[4] = {
105 * sizeof(int32_t),
35 * sizeof(int32_t),
7 * sizeof(int32_t),
1 * sizeof(int32_t)
};
assert_arrays_match("strides", "%u", 4u, expected_strides, strides);
}
void test_ndarray_indices_iter_normal() {
// Test NDArrayIndicesIter normal behavior
BEGIN_TEST();
int32_t shape[3] = { 1, 2, 3 };
int32_t indices[3] = { 0, 0, 0 };
auto iter = NDArrayIndicesIter<int32_t> {
.ndims = 3,
.shape = shape,
.indices = indices
};
assert_arrays_match("indices #0", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 0 });
iter.next();
assert_arrays_match("indices #1", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 1 });
iter.next();
assert_arrays_match("indices #2", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 2 });
iter.next();
assert_arrays_match("indices #3", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 0 });
iter.next();
assert_arrays_match("indices #4", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 1 });
iter.next();
assert_arrays_match("indices #5", "%u", 3u, iter.indices, (int32_t[3]) { 0, 1, 2 });
iter.next();
assert_arrays_match("indices #6", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 0 }); // Loops back
iter.next();
assert_arrays_match("indices #7", "%u", 3u, iter.indices, (int32_t[3]) { 0, 0, 1 });
}
void test_ndarray_fill_generic() {
// Test ndarray fill_generic
BEGIN_TEST();
// Choose a type that's neither int32_t nor uint64_t (candidates of SizeT) to spice it up
// Also make all the octets non-zero, to see if `memcpy` in `fill_generic` is working perfectly.
uint16_t fill_value = 0xFACE;
uint16_t in_data[6] = { 100, 101, 102, 103, 104, 105 }; // Fill `data` with values that != `999`
int32_t in_itemsize = sizeof(uint16_t);
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = { 2, 3 };
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {
.data = (uint8_t*) in_data,
.itemsize = in_itemsize,
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides,
};
ndarray.set_strides_by_shape();
ndarray.fill_generic((uint8_t*) &fill_value); // `fill_generic` here
uint16_t expected_data[6] = { fill_value, fill_value, fill_value, fill_value, fill_value, fill_value };
assert_arrays_match("data", "0x%hX", 6, expected_data, in_data);
}
void test_ndarray_set_to_eye() {
// Test `set_to_eye` behavior (helper function to implement `np.eye()`)
BEGIN_TEST();
double in_data[9] = { 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0, 99.0 };
int32_t in_itemsize = sizeof(double);
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = { 3, 3 };
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {
.data = (uint8_t*) in_data,
.itemsize = in_itemsize,
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides,
};
ndarray.set_strides_by_shape();
double zero = 0.0;
double one = 1.0;
ndarray.set_to_eye(1, (uint8_t*) &zero, (uint8_t*) &one);
assert_values_match("in_data[0]", "%f", 0.0, in_data[0]);
assert_values_match("in_data[1]", "%f", 1.0, in_data[1]);
assert_values_match("in_data[2]", "%f", 0.0, in_data[2]);
assert_values_match("in_data[3]", "%f", 0.0, in_data[3]);
assert_values_match("in_data[4]", "%f", 0.0, in_data[4]);
assert_values_match("in_data[5]", "%f", 1.0, in_data[5]);
assert_values_match("in_data[6]", "%f", 0.0, in_data[6]);
assert_values_match("in_data[7]", "%f", 0.0, in_data[7]);
assert_values_match("in_data[8]", "%f", 0.0, in_data[8]);
}
void test_slice_1() {
// Test `slice(5, None, None).indices(100) == slice(5, 100, 1)`
BEGIN_TEST();
UserSlice<int> user_slice = {
.start_defined = 1,
.start = 5,
.stop_defined = 0,
.step_defined = 0,
};
auto slice = user_slice.indices(100);
assert_values_match("start", "%d", 5, slice.start);
assert_values_match("stop", "%d", 100, slice.stop);
assert_values_match("step", "%d", 1, slice.step);
}
void test_slice_2() {
// Test `slice(400, 999, None).indices(100) == slice(100, 100, 1)`
BEGIN_TEST();
UserSlice<int> user_slice = {
.start_defined = 1,
.start = 400,
.stop_defined = 0,
.step_defined = 0,
};
auto slice = user_slice.indices(100);
assert_values_match("start", "%d", 100, slice.start);
assert_values_match("stop", "%d", 100, slice.stop);
assert_values_match("step", "%d", 1, slice.step);
}
void test_slice_3() {
// Test `slice(-10, -5, None).indices(100) == slice(90, 95, 1)`
BEGIN_TEST();
UserSlice<int> user_slice = {
.start_defined = 1,
.start = -10,
.stop_defined = 1,
.stop = -5,
.step_defined = 0,
};
auto slice = user_slice.indices(100);
assert_values_match("start", "%d", 90, slice.start);
assert_values_match("stop", "%d", 95, slice.stop);
assert_values_match("step", "%d", 1, slice.step);
}
void test_slice_4() {
// Test `slice(None, None, -5).indices(100) == (99, -1, -5)`
BEGIN_TEST();
UserSlice<int> user_slice = {
.start_defined = 0,
.stop_defined = 0,
.step_defined = 1,
.step = -5
};
auto slice = user_slice.indices(100);
assert_values_match("start", "%d", 99, slice.start);
assert_values_match("stop", "%d", -1, slice.stop);
assert_values_match("step", "%d", -5, slice.step);
}
void test_ndslice_1() {
/*
Reference Python code:
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4));
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[-2:, 1::2]
# array([[ 5., 7.],
# [ 9., 11.]])
assert dst_ndarray.shape == (2, 2)
assert dst_ndarray.strides == (32, 16)
assert dst_ndarray[0, 0] == 5.0
assert dst_ndarray[0, 1] == 7.0
assert dst_ndarray[1, 0] == 9.0
assert dst_ndarray[1, 1] == 11.0
```
*/
BEGIN_TEST();
double in_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
int32_t in_itemsize = sizeof(double);
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = { 3, 4 };
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {
.data = (uint8_t*) in_data,
.itemsize = in_itemsize,
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides
};
ndarray.set_strides_by_shape();
// Destination ndarray
// As documented, ndims and shape & strides must be allocated and determined by the caller.
const int32_t dst_ndims = 2;
int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values
int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values
NDArray<int32_t> dst_ndarray = {
.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides
};
// Create the slice in `ndarray[-2::, 1::2]`
UserSlice<int32_t> user_slice_1 = {
.start_defined = 1,
.start = -2,
.stop_defined = 0,
.step_defined = 0
};
UserSlice<int32_t> user_slice_2 = {
.start_defined = 1,
.start = 1,
.stop_defined = 0,
.step_defined = 1,
.step = 2
};
const int32_t num_ndslices = 2;
NDSlice ndslices[num_ndslices] = {
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_1 },
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 }
};
ndarray.slice(num_ndslices, ndslices, &dst_ndarray);
int32_t expected_shape[dst_ndims] = { 2, 2 };
int32_t expected_strides[dst_ndims] = { 32, 16 };
assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
assert_values_match("dst_ndarray[0, 0]", "%f", 5.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 0 })));
assert_values_match("dst_ndarray[0, 1]", "%f", 7.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 1 })));
assert_values_match("dst_ndarray[1, 0]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 0 })));
assert_values_match("dst_ndarray[1, 1]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 1 })));
}
void test_ndslice_2() {
/*
```python
ndarray = np.arange(12, dtype=np.float64).reshape((3, 4))
# array([[ 0., 1., 2., 3.],
# [ 4., 5., 6., 7.],
# [ 8., 9., 10., 11.]])
dst_ndarray = ndarray[2, ::-2]
# array([11., 9.])
assert dst_ndarray.shape == (2,)
assert dst_ndarray.strides == (-16,)
assert dst_ndarray[0] == 11.0
assert dst_ndarray[1] == 9.0
dst_ndarray[1, 0] == 99 # If you write to `dst_ndarray`
assert ndarray[1, 3] == 99 # `ndarray` also updates!!
```
*/
BEGIN_TEST();
double in_data[12] = { 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0 };
int32_t in_itemsize = sizeof(double);
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = { 3, 4 };
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {
.data = (uint8_t*) in_data,
.itemsize = in_itemsize,
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides
};
ndarray.set_strides_by_shape();
// Destination ndarray
// As documented, ndims and shape & strides must be allocated and determined by the caller.
const int32_t dst_ndims = 1;
int32_t dst_shape[dst_ndims] = {999}; // Empty values
int32_t dst_strides[dst_ndims] = {999}; // Empty values
NDArray<int32_t> dst_ndarray = {
.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides
};
// Create the slice in `ndarray[2, ::-2]`
int32_t user_slice_1 = 2;
UserSlice<int32_t> user_slice_2 = {
.start_defined = 0,
.stop_defined = 0,
.step_defined = 1,
.step = -2
};
const int32_t num_ndslices = 2;
NDSlice ndslices[num_ndslices] = {
{ .type = INPUT_SLICE_TYPE_INDEX, .slice = (uint8_t*) &user_slice_1 },
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_2 }
};
ndarray.slice(num_ndslices, ndslices, &dst_ndarray);
int32_t expected_shape[dst_ndims] = { 2 };
int32_t expected_strides[dst_ndims] = { -16 };
assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
// [5.0, 3.0]
assert_values_match("dst_ndarray[0]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0 })));
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1 })));
}
void test_can_broadcast_shape() {
BEGIN_TEST();
assert_values_match(
"can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 5, (int32_t[]) { 1, 1, 1, 1, 3 })
);
assert_values_match(
"can_broadcast_shape_to([3], [3, 1]) == false",
"%d",
false,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 2, (int32_t[]) { 3, 1 }));
assert_values_match(
"can_broadcast_shape_to([3], [3]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 3 }, 1, (int32_t[]) { 3 }));
assert_values_match(
"can_broadcast_shape_to([1], [3]) == false",
"%d",
false,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 1 }, 1, (int32_t[]) { 3 }));
assert_values_match(
"can_broadcast_shape_to([1], [1]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 1 }, 1, (int32_t[]) { 1 }));
assert_values_match(
"can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 3, (int32_t[]) { 256, 1, 3 })
);
assert_values_match(
"can_broadcast_shape_to([256, 256, 3], [3]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 3 })
);
assert_values_match(
"can_broadcast_shape_to([256, 256, 3], [2]) == false",
"%d",
false,
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 2 })
);
assert_values_match(
"can_broadcast_shape_to([256, 256, 3], [1]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(3, (int32_t[]) { 256, 256, 3 }, 1, (int32_t[]) { 1 })
);
// In cases when the shapes contain zero(es)
assert_values_match(
"can_broadcast_shape_to([0], [1]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 0 }, 1, (int32_t[]) { 1 })
);
assert_values_match(
"can_broadcast_shape_to([0], [2]) == false",
"%d",
false,
ndarray_util::can_broadcast_shape_to(1, (int32_t[]) { 0 }, 1, (int32_t[]) { 2 })
);
assert_values_match(
"can_broadcast_shape_to([0, 4, 0, 0], [1]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 1, (int32_t[]) { 1 })
);
assert_values_match(
"can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 4, (int32_t[]) { 1, 1, 1, 1 })
);
assert_values_match(
"can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true",
"%d",
true,
ndarray_util::can_broadcast_shape_to(4, (int32_t[]) { 0, 4, 0, 0 }, 4, (int32_t[]) { 1, 4, 1, 1 })
);
assert_values_match(
"can_broadcast_shape_to([4, 3], [0, 3]) == false",
"%d",
false,
ndarray_util::can_broadcast_shape_to(2, (int32_t[]) { 4, 3 }, 2, (int32_t[]) { 0, 3 })
);
assert_values_match(
"can_broadcast_shape_to([4, 3], [0, 0]) == false",
"%d",
false,
ndarray_util::can_broadcast_shape_to(2, (int32_t[]) { 4, 3 }, 2, (int32_t[]) { 0, 0 })
);
}
void test_ndarray_broadcast_1() {
/*
# array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
# >>> [[19.9 29.9 39.9 49.9]]
#
# array = np.broadcast_to(array, (2, 3, 4))
# >>> [[[19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]]
# >>> [[19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]
# >>> [19.9 29.9 39.9 49.9]]]
#
# assery array.strides == (0, 0, 8)
*/
BEGIN_TEST();
double in_data[4] = { 19.9, 29.9, 39.9, 49.9 };
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = {1, 4};
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {
.data = (uint8_t*) in_data,
.itemsize = sizeof(double),
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides
};
ndarray.set_strides_by_shape();
const int32_t dst_ndims = 3;
int32_t dst_shape[dst_ndims] = {2, 3, 4};
int32_t dst_strides[dst_ndims] = {};
NDArray<int32_t> dst_ndarray = {
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides
};
ndarray.broadcast_to(&dst_ndarray);
assert_arrays_match("dst_ndarray->strides", "%d", dst_ndims, (int32_t[]) { 0, 0, 8 }, dst_ndarray.strides);
assert_values_match("dst_ndarray[0, 0, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 0})));
assert_values_match("dst_ndarray[0, 0, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 1})));
assert_values_match("dst_ndarray[0, 0, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 2})));
assert_values_match("dst_ndarray[0, 0, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 3})));
assert_values_match("dst_ndarray[0, 1, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 0})));
assert_values_match("dst_ndarray[0, 1, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 1})));
assert_values_match("dst_ndarray[0, 1, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 2})));
assert_values_match("dst_ndarray[0, 1, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 3})));
assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {1, 2, 3})));
}
void test_assign_with() {
/*
```
xs = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=np.float64)
ys = xs.shape
```
*/
}
int main() {
test_calc_size_from_shape_normal();
test_calc_size_from_shape_has_zero();
test_set_strides_by_shape();
test_ndarray_indices_iter_normal();
test_ndarray_fill_generic();
test_ndarray_set_to_eye();
test_slice_1();
test_slice_2();
test_slice_3();
test_slice_4();
test_ndslice_1();
test_ndslice_2();
test_can_broadcast_shape();
test_ndarray_broadcast_1();
test_assign_with();
return 0;
}

View File

@ -0,0 +1,14 @@
#pragma once
// This is made toggleable since `irrt_test.cpp` itself would include
// headers that define the `int_t` family.
#ifndef IRRT_DONT_TYPEDEF_INTS
typedef _BitInt(8) int8_t;
typedef unsigned _BitInt(8) uint8_t;
typedef _BitInt(32) int32_t;
typedef unsigned _BitInt(32) uint32_t;
typedef _BitInt(64) int64_t;
typedef unsigned _BitInt(64) uint64_t;
#endif
typedef int32_t SliceIndex;

View File

@ -0,0 +1,37 @@
#pragma once
#include "irrt_typedefs.hpp"
namespace {
template <typename T>
T max(T a, T b) {
return a > b ? a : b;
}
template <typename T>
T min(T a, T b) {
return a > b ? b : a;
}
template <typename T>
bool arrays_match(int len, T *as, T *bs) {
for (int i = 0; i < len; i++) {
if (as[i] != bs[i]) return false;
}
return true;
}
void irrt_panic() {
// Crash the program for now.
// TODO: Don't crash the program
// ... or at least produce a good message when doing testing IRRT
uint8_t* death = nullptr;
*death = 0; // TODO: address 0 on hardware might be writable?
}
// TODO: Make this a macro and allow it to be toggled on/off (e.g., debug vs release)
void irrt_assert(bool condition) {
if (!condition) irrt_panic();
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -3,10 +3,13 @@ use crate::{
toplevel::DefinitionId, toplevel::DefinitionId,
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, typedef::{
into_var_map, FunSignature, FuncArg, Type, TypeEnum, TypeVar, TypeVarId, Unifier,
},
}, },
}; };
use indexmap::IndexMap;
use nac3parser::ast::StrRef; use nac3parser::ast::StrRef;
use std::collections::HashMap; use std::collections::HashMap;
@ -44,13 +47,10 @@ pub enum ConcreteTypeEnum {
TTuple { TTuple {
ty: Vec<ConcreteType>, ty: Vec<ConcreteType>,
}, },
TList {
ty: ConcreteType,
},
TObj { TObj {
obj_id: DefinitionId, obj_id: DefinitionId,
fields: HashMap<StrRef, (ConcreteType, bool)>, fields: HashMap<StrRef, (ConcreteType, bool)>,
params: HashMap<u32, ConcreteType>, params: IndexMap<TypeVarId, ConcreteType>,
}, },
TVirtual { TVirtual {
ty: ConcreteType, ty: ConcreteType,
@ -58,11 +58,10 @@ pub enum ConcreteTypeEnum {
TFunc { TFunc {
args: Vec<ConcreteFuncArg>, args: Vec<ConcreteFuncArg>,
ret: ConcreteType, ret: ConcreteType,
vars: HashMap<u32, ConcreteType>, vars: HashMap<TypeVarId, ConcreteType>,
}, },
TConstant { TLiteral {
value: SymbolValue, values: Vec<SymbolValue>,
ty: ConcreteType,
}, },
} }
@ -165,9 +164,6 @@ impl ConcreteTypeStore {
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache)) .map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
.collect(), .collect(),
}, },
TypeEnum::TList { ty } => ConcreteTypeEnum::TList {
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
},
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj { TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
obj_id: *obj_id, obj_id: *obj_id,
fields: fields fields: fields
@ -202,10 +198,9 @@ impl ConcreteTypeStore {
TypeEnum::TFunc(signature) => { TypeEnum::TFunc(signature) => {
self.from_signature(unifier, primitives, signature, cache) self.from_signature(unifier, primitives, signature, cache)
} }
TypeEnum::TConstant { value, ty, .. } => ConcreteTypeEnum::TConstant { TypeEnum::TLiteral { values, .. } => {
value: value.clone(), ConcreteTypeEnum::TLiteral { values: values.clone() }
ty: self.from_unifier_type(unifier, primitives, *ty, cache), }
},
_ => unreachable!("{:?}", ty_enum.get_type_name()), _ => unreachable!("{:?}", ty_enum.get_type_name()),
}; };
let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() { let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() {
@ -231,7 +226,7 @@ impl ConcreteTypeStore {
return if let Some(ty) = ty { return if let Some(ty) = ty {
*ty *ty
} else { } else {
*ty = Some(unifier.get_dummy_var().0); *ty = Some(unifier.get_dummy_var().ty);
ty.unwrap() ty.unwrap()
}; };
} }
@ -259,9 +254,6 @@ impl ConcreteTypeStore {
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache)) .map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
.collect(), .collect(),
}, },
ConcreteTypeEnum::TList { ty } => {
TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
}
ConcreteTypeEnum::TVirtual { ty } => { ConcreteTypeEnum::TVirtual { ty } => {
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
} }
@ -273,10 +265,10 @@ impl ConcreteTypeStore {
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
}) })
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
params: params params: into_var_map(params.iter().map(|(&id, cty)| {
.iter() let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) TypeVar { id, ty }
.collect::<HashMap<_, _>>(), })),
}, },
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
args: args args: args
@ -288,15 +280,13 @@ impl ConcreteTypeStore {
}) })
.collect(), .collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache), ret: self.to_unifier_type(unifier, primitives, *ret, cache),
vars: vars vars: into_var_map(vars.iter().map(|(&id, cty)| {
.iter() let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) TypeVar { id, ty }
.collect::<HashMap<_, _>>(), })),
}), }),
ConcreteTypeEnum::TConstant { value, ty } => TypeEnum::TConstant { ConcreteTypeEnum::TLiteral { values, .. } => {
value: value.clone(), TypeEnum::TLiteral { values: values.clone(), loc: None }
ty: self.to_unifier_type(unifier, primitives, *ty, cache),
loc: None,
} }
}; };
let result = unifier.add_ty(result); let result = unifier.add_ty(result);

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,132 @@
use inkwell::attributes::{Attribute, AttributeLoc};
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue};
use itertools::Either;
use crate::codegen::CodeGenContext;
/// Macro to generate extern function
/// Both function return type and function parameter type are `FloatValue`
///
/// Arguments:
/// * `unary/binary`: Whether the extern function requires one (unary) or two (binary) operands
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
///
/// Optional Arguments:
/// * `$(,$attributes:literal)*)`: Attributes linked with the extern function
/// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly"
/// These will be used unless other attributes are specified
/// * `$(,$args:ident)*`: Operands of the extern function
/// The data type of these operands will be set to `FloatValue`
///
macro_rules! generate_extern_fn {
("unary", $fn_name:ident, $extern_fn:literal) => {
generate_extern_fn!($fn_name, $extern_fn, arg, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
};
("unary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
generate_extern_fn!($fn_name, $extern_fn, arg $(,$attributes)*);
};
("binary", $fn_name:ident, $extern_fn:literal) => {
generate_extern_fn!($fn_name, $extern_fn, arg1, arg2, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
};
("binary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
generate_extern_fn!($fn_name, $extern_fn, arg1, arg2 $(,$attributes)*);
};
($fn_name:ident, $extern_fn:literal $(,$args:ident)* $(,$attributes:literal)*) => {
#[doc = concat!("Invokes the [`", stringify!($extern_fn), "`](https://en.cppreference.com/w/c/numeric/math/", stringify!($llvm_name), ") function." )]
pub fn $fn_name<'ctx>(
ctx: &CodeGenContext<'ctx, '_>
$(,$args: FloatValue<'ctx>)*,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = $extern_fn;
let llvm_f64 = ctx.ctx.f64_type();
$(debug_assert_eq!($args.get_type(), llvm_f64);)*
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[$($args.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in [$($attributes),*] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[$($args.into()),*], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
};
}
generate_extern_fn!("unary", call_tan, "tan");
generate_extern_fn!("unary", call_asin, "asin");
generate_extern_fn!("unary", call_acos, "acos");
generate_extern_fn!("unary", call_atan, "atan");
generate_extern_fn!("unary", call_sinh, "sinh");
generate_extern_fn!("unary", call_cosh, "cosh");
generate_extern_fn!("unary", call_tanh, "tanh");
generate_extern_fn!("unary", call_asinh, "asinh");
generate_extern_fn!("unary", call_acosh, "acosh");
generate_extern_fn!("unary", call_atanh, "atanh");
generate_extern_fn!("unary", call_expm1, "expm1");
generate_extern_fn!(
"unary",
call_cbrt,
"cbrt",
"mustprogress",
"nofree",
"nosync",
"nounwind",
"readonly",
"willreturn"
);
generate_extern_fn!("unary", call_erf, "erf", "nounwind");
generate_extern_fn!("unary", call_erfc, "erfc", "nounwind");
generate_extern_fn!("unary", call_j1, "j1", "nounwind");
generate_extern_fn!("binary", call_atan2, "atan2");
generate_extern_fn!("binary", call_hypot, "hypot", "nounwind");
generate_extern_fn!("binary", call_nextafter, "nextafter", "nounwind");
/// Invokes the [`ldexp`](https://en.cppreference.com/w/c/numeric/math/ldexp) function.
pub fn call_ldexp<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
exp: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "ldexp";
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
debug_assert_eq!(exp.get_type(), llvm_i32);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,5 +1,5 @@
use crate::{ use crate::{
codegen::{expr::*, stmt::*, bool_to_i1, bool_to_i8, CodeGenContext}, codegen::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext},
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type}, typecheck::typedef::{FunSignature, Type},
@ -92,6 +92,18 @@ pub trait CodeGenerator {
gen_var(ctx, ty, name) gen_var(ctx, ty, name)
} }
/// Allocate memory for a variable and return a pointer pointing to it.
/// The default implementation places the allocations at the start of the function.
fn gen_array_var_alloc<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Result<ArraySliceValue<'ctx>, String> {
gen_array_var(ctx, ty, size, name)
}
/// Return a pointer pointing to the target of the expression. /// Return a pointer pointing to the target of the expression.
fn gen_store_target<'ctx>( fn gen_store_target<'ctx>(
&mut self, &mut self,
@ -131,8 +143,8 @@ pub trait CodeGenerator {
gen_while(self, ctx, stmt) gen_while(self, ctx, stmt)
} }
/// Generate code for a while expression. /// Generate code for a for expression.
/// Return true if the while loop must early return /// Return true if the for loop must early return
fn gen_for( fn gen_for(
&mut self, &mut self,
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
@ -198,7 +210,7 @@ pub trait CodeGenerator {
fn bool_to_i1<'ctx>( fn bool_to_i1<'ctx>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx> bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
bool_to_i1(&ctx.builder, bool_value) bool_to_i1(&ctx.builder, bool_value)
} }
@ -207,7 +219,7 @@ pub trait CodeGenerator {
fn bool_to_i8<'ctx>( fn bool_to_i8<'ctx>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx> bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
bool_to_i8(&ctx.builder, ctx.ctx, bool_value) bool_to_i8(&ctx.builder, ctx.ctx, bool_value)
} }
@ -227,7 +239,6 @@ impl DefaultCodeGenerator {
} }
impl CodeGenerator for DefaultCodeGenerator { impl CodeGenerator for DefaultCodeGenerator {
/// Returns the name for this [`CodeGenerator`]. /// Returns the name for this [`CodeGenerator`].
fn get_name(&self) -> &str { fn get_name(&self) -> &str {
&self.name &self.name

View File

@ -1,199 +0,0 @@
typedef _BitInt(8) int8_t;
typedef unsigned _BitInt(8) uint8_t;
typedef _BitInt(32) int32_t;
typedef unsigned _BitInt(32) uint32_t;
typedef _BitInt(64) int64_t;
typedef unsigned _BitInt(64) uint64_t;
# define MAX(a, b) (a > b ? a : b)
# define MIN(a, b) (a > b ? b : a)
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
#define DEF_INT_EXP(T) T __nac3_int_exp_##T( \
T base, \
T exp \
) { \
T res = (T)1; \
/* repeated squaring method */ \
do { \
if (exp & 1) res *= base; /* for n odd */ \
exp >>= 1; \
base *= base; \
} while (exp); \
return res; \
} \
DEF_INT_EXP(int32_t)
DEF_INT_EXP(int64_t)
DEF_INT_EXP(uint32_t)
DEF_INT_EXP(uint64_t)
int32_t __nac3_slice_index_bound(int32_t i, const int32_t len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
int32_t __nac3_range_slice_len(const int32_t start, const int32_t end, const int32_t step) {
int32_t diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
int32_t __nac3_list_slice_assign_var_size(
int32_t dest_start,
int32_t dest_end,
int32_t dest_step,
uint8_t *dest_arr,
int32_t dest_arr_len,
int32_t src_start,
int32_t src_end,
int32_t src_step,
uint8_t *src_arr,
int32_t src_arr_len,
const int32_t size
) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const int32_t src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const int32_t dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(
dest_arr + dest_start * size,
src_arr + src_start * size,
src_len * size
);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(
dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca =
(dest_arr == src_arr)
&& !(
MAX(dest_start, dest_end) < MIN(src_start, src_end)
|| MAX(src_start, src_end) < MIN(dest_start, dest_end)
);
if (need_alloca) {
uint8_t *tmp = __builtin_alloca(src_arr_len * size);
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
int32_t src_ind = src_start;
int32_t dest_ind = dest_start;
for (;
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
src_ind += src_step, dest_ind += dest_step
) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(
dest_arr + dest_ind * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}

View File

@ -1,15 +1,26 @@
use crate::typecheck::typedef::Type; use crate::{typecheck::typedef::Type, util::SizeVariant};
use super::{CodeGenContext, CodeGenerator}; mod test;
use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, NpArrayType,
NpArrayValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
llvm_intrinsics, CodeGenContext, CodeGenerator,
};
use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing;
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
context::Context, context::Context,
memory_buffer::MemoryBuffer, memory_buffer::MemoryBuffer,
module::Module, module::Module,
types::BasicTypeEnum, types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType},
values::{FloatValue, IntValue, PointerValue}, values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
use itertools::Either;
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
#[must_use] #[must_use]
@ -34,8 +45,8 @@ pub fn load_irrt(ctx: &Context) -> Module {
// repeated squaring method adapted from GNU Scientific Library: // repeated squaring method adapted from GNU Scientific Library:
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c // https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
pub fn integer_power<'ctx>( pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut dyn CodeGenerator, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
base: IntValue<'ctx>, base: IntValue<'ctx>,
exp: IntValue<'ctx>, exp: IntValue<'ctx>,
@ -54,12 +65,15 @@ pub fn integer_power<'ctx>(
ctx.module.add_function(symbol, fn_type, None) ctx.module.add_function(symbol, fn_type, None)
}); });
// throw exception when exp < 0 // throw exception when exp < 0
let ge_zero = ctx.builder.build_int_compare( let ge_zero = ctx
.builder
.build_int_compare(
IntPredicate::SGE, IntPredicate::SGE,
exp, exp,
exp.get_type().const_zero(), exp.get_type().const_zero(),
"assert_int_pow_ge_0", "assert_int_pow_ge_0",
); )
.unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
ge_zero, ge_zero,
@ -70,13 +84,14 @@ pub fn integer_power<'ctx>(
); );
ctx.builder ctx.builder
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") .build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
.try_as_basic_value() .map(CallSiteValue::try_as_basic_value)
.unwrap_left() .map(|v| v.map_left(BasicValueEnum::into_int_value))
.into_int_value() .map(Either::unwrap_left)
.unwrap()
} }
pub fn calculate_len_for_slice_range<'ctx>( pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut dyn CodeGenerator, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
start: IntValue<'ctx>, start: IntValue<'ctx>,
end: IntValue<'ctx>, end: IntValue<'ctx>,
@ -90,12 +105,10 @@ pub fn calculate_len_for_slice_range<'ctx>(
}); });
// assert step != 0, throw exception if not // assert step != 0, throw exception if not
let not_zero = ctx.builder.build_int_compare( let not_zero = ctx
IntPredicate::NE, .builder
step, .build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
step.get_type().const_zero(), .unwrap();
"range_step_ne",
);
ctx.make_assert( ctx.make_assert(
generator, generator,
not_zero, not_zero,
@ -106,10 +119,10 @@ pub fn calculate_len_for_slice_range<'ctx>(
); );
ctx.builder ctx.builder
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
.try_as_basic_value() .map(CallSiteValue::try_as_basic_value)
.left() .map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap() .unwrap()
.into_int_value()
} }
/// NOTE: the output value of the end index of this function should be compared ***inclusively***, /// NOTE: the output value of the end index of this function should be compared ***inclusively***,
@ -158,13 +171,12 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
step: &Option<Box<Expr<Option<Type>>>>, step: &Option<Box<Expr<Option<Type>>>>,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G, generator: &mut G,
list: PointerValue<'ctx>, length: IntValue<'ctx>,
) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> { ) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero(); let zero = int32.const_zero();
let one = int32.const_int(1, false); let one = int32.const_int(1, false);
let length = ctx.build_gep_and_load(list, &[zero, one], Some("length")).into_int_value(); let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32");
Ok(Some(match (start, end, step) { Ok(Some(match (start, end, step) {
(s, e, None) => ( (s, e, None) => (
if let Some(s) = s.as_ref() { if let Some(s) = s.as_ref() {
@ -184,7 +196,7 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
} else { } else {
length length
}; };
ctx.builder.build_int_sub(e, one, "final_end") ctx.builder.build_int_sub(e, one, "final_end").unwrap()
}, },
one, one,
), ),
@ -192,15 +204,18 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
let step = if let Some(v) = generator.gen_expr(ctx, step)? { let step = if let Some(v) = generator.gen_expr(ctx, step)? {
v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value() v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value()
} else { } else {
return Ok(None) return Ok(None);
}; };
// assert step != 0, throw exception if not // assert step != 0, throw exception if not
let not_zero = ctx.builder.build_int_compare( let not_zero = ctx
.builder
.build_int_compare(
IntPredicate::NE, IntPredicate::NE,
step, step,
step.get_type().const_zero(), step.get_type().const_zero(),
"range_step_ne", "range_step_ne",
); )
.unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
not_zero, not_zero,
@ -209,49 +224,66 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
[None, None, None], [None, None, None],
ctx.current_loc, ctx.current_loc,
); );
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1"); let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap();
let neg = ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg"); let neg = ctx
.builder
.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg")
.unwrap();
( (
match s { match s {
Some(s) => { Some(s) => {
let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else { let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else {
return Ok(None) return Ok(None);
}; };
ctx.builder ctx.builder
.build_select( .build_select(
ctx.builder.build_and( ctx.builder
ctx.builder.build_int_compare( .build_and(
ctx.builder
.build_int_compare(
IntPredicate::EQ, IntPredicate::EQ,
s, s,
length, length,
"s_eq_len", "s_eq_len",
), )
.unwrap(),
neg, neg,
"should_minus_one", "should_minus_one",
), )
ctx.builder.build_int_sub(s, one, "s_min"), .unwrap(),
ctx.builder.build_int_sub(s, one, "s_min").unwrap(),
s, s,
"final_start", "final_start",
) )
.into_int_value() .map(BasicValueEnum::into_int_value)
.unwrap()
} }
None => ctx.builder.build_select(neg, len_id, zero, "stt").into_int_value(), None => ctx
.builder
.build_select(neg, len_id, zero, "stt")
.map(BasicValueEnum::into_int_value)
.unwrap(),
}, },
match e { match e {
Some(e) => { Some(e) => {
let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else { let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else {
return Ok(None) return Ok(None);
}; };
ctx.builder ctx.builder
.build_select( .build_select(
neg, neg,
ctx.builder.build_int_add(e, one, "end_add_one"), ctx.builder.build_int_add(e, one, "end_add_one").unwrap(),
ctx.builder.build_int_sub(e, one, "end_sub_one"), ctx.builder.build_int_sub(e, one, "end_sub_one").unwrap(),
"final_end", "final_end",
) )
.into_int_value() .map(BasicValueEnum::into_int_value)
.unwrap()
} }
None => ctx.builder.build_select(neg, zero, len_id, "end").into_int_value(), None => ctx
.builder
.build_select(neg, zero, len_id, "end")
.map(BasicValueEnum::into_int_value)
.unwrap(),
}, },
step, step,
) )
@ -277,27 +309,28 @@ pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
let i = if let Some(v) = generator.gen_expr(ctx, i)? { let i = if let Some(v) = generator.gen_expr(ctx, i)? {
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())? v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
} else { } else {
return Ok(None) return Ok(None);
}; };
Ok(Some(ctx Ok(Some(
.builder ctx.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind") .build_call(func, &[i.into(), length.into()], "bounded_ind")
.try_as_basic_value() .map(CallSiteValue::try_as_basic_value)
.left() .map(|v| v.map_left(BasicValueEnum::into_int_value))
.unwrap() .map(Either::unwrap_left)
.into_int_value())) .unwrap(),
))
} }
/// This function handles 'end' **inclusively**. /// This function handles 'end' **inclusively**.
/// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step'). /// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
/// Negative index should be handled before entering this function /// Negative index should be handled before entering this function
pub fn list_slice_assignment<'ctx>( pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut dyn CodeGenerator, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>, ty: BasicTypeEnum<'ctx>,
dest_arr: PointerValue<'ctx>, dest_arr: ListValue<'ctx>,
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
src_arr: PointerValue<'ctx>, src_arr: ListValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) { ) {
let size_ty = generator.get_size_type(ctx.ctx); let size_ty = generator.get_size_type(ctx.ctx);
@ -326,76 +359,63 @@ pub fn list_slice_assignment<'ctx>(
let zero = int32.const_zero(); let zero = int32.const_zero();
let one = int32.const_int(1, false); let one = int32.const_int(1, false);
let dest_arr_ptr = ctx.build_gep_and_load(dest_arr, &[zero, zero], Some("dest.addr")); let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
let dest_arr_ptr = ctx.builder.build_pointer_cast( let dest_arr_ptr =
dest_arr_ptr.into_pointer_value(), ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
elem_ptr_type, let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
"dest_arr_ptr_cast", let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
); let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
let dest_len = ctx.build_gep_and_load(dest_arr, &[zero, one], Some("dest.len")).into_int_value(); let src_arr_ptr =
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32"); ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
let src_arr_ptr = ctx.build_gep_and_load(src_arr, &[zero, zero], Some("src.addr")); let src_len = src_arr.load_size(ctx, Some("src.len"));
let src_arr_ptr = ctx.builder.build_pointer_cast( let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
src_arr_ptr.into_pointer_value(),
elem_ptr_type,
"src_arr_ptr_cast",
);
let src_len = ctx.build_gep_and_load(src_arr, &[zero, one], Some("src.len")).into_int_value();
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32");
// index in bound and positive should be done // index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
// throw exception if not satisfied // throw exception if not satisfied
let src_end = ctx.builder let src_end = ctx
.builder
.build_select( .build_select(
ctx.builder.build_int_compare( ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
IntPredicate::SLT, ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
src_idx.2, ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
zero,
"is_neg",
),
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one"),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one"),
"final_e", "final_e",
) )
.into_int_value(); .map(BasicValueEnum::into_int_value)
let dest_end = ctx.builder .unwrap();
let dest_end = ctx
.builder
.build_select( .build_select(
ctx.builder.build_int_compare( ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
IntPredicate::SLT, ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
dest_idx.2, ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
zero,
"is_neg",
),
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one"),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one"),
"final_e", "final_e",
) )
.into_int_value(); .map(BasicValueEnum::into_int_value)
.unwrap();
let src_slice_len = let src_slice_len =
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2); calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len = let dest_slice_len =
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2); calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
let src_eq_dest = ctx.builder.build_int_compare( let src_eq_dest = ctx
IntPredicate::EQ, .builder
src_slice_len, .build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
dest_slice_len, .unwrap();
"slice_src_eq_dest", let src_slt_dest = ctx
); .builder
let src_slt_dest = ctx.builder.build_int_compare( .build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
IntPredicate::SLT, .unwrap();
src_slice_len, let dest_step_eq_one = ctx
dest_slice_len, .builder
"slice_src_slt_dest", .build_int_compare(
);
let dest_step_eq_one = ctx.builder.build_int_compare(
IntPredicate::EQ, IntPredicate::EQ,
dest_idx.2, dest_idx.2,
dest_idx.2.get_type().const_int(1, false), dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one", "slice_dest_step_eq_one",
); )
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1"); .unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond"); let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
cond, cond,
@ -425,34 +445,34 @@ pub fn list_slice_assignment<'ctx>(
BasicTypeEnum::StructType(t) => t.size_of().unwrap(), BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => unreachable!(), _ => unreachable!(),
}; };
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size") ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
} }
.into(), .into(),
]; ];
ctx.builder ctx.builder
.build_call(slice_assign_fun, args.as_slice(), "slice_assign") .build_call(slice_assign_fun, args.as_slice(), "slice_assign")
.try_as_basic_value() .map(CallSiteValue::try_as_basic_value)
.unwrap_left() .map(|v| v.map_left(BasicValueEnum::into_int_value))
.into_int_value() .map(Either::unwrap_left)
.unwrap()
}; };
// update length // update length
let need_update = let need_update =
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update"); ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let update_bb = ctx.ctx.append_basic_block(current, "update"); let update_bb = ctx.ctx.append_basic_block(current, "update");
let cont_bb = ctx.ctx.append_basic_block(current, "cont"); let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb); ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
ctx.builder.position_at_end(update_bb); ctx.builder.position_at_end(update_bb);
let dest_len_ptr = unsafe { ctx.builder.build_gep(dest_arr, &[zero, one], "dest_len_ptr") }; let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len"); dest_arr.store_size(ctx, generator, new_len);
ctx.builder.build_store(dest_len_ptr, new_len); ctx.builder.build_unconditional_branch(cont_bb).unwrap();
ctx.builder.build_unconditional_branch(cont_bb);
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
} }
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result. /// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
pub fn call_isinf<'ctx>( pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut dyn CodeGenerator, generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>, v: FloatValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
@ -461,18 +481,20 @@ pub fn call_isinf<'ctx>(
ctx.module.add_function("__nac3_isinf", fn_type, None) ctx.module.add_function("__nac3_isinf", fn_type, None)
}); });
let ret = ctx.builder let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isinf") .build_call(intrinsic_fn, &[v.into()], "isinf")
.try_as_basic_value() .map(CallSiteValue::try_as_basic_value)
.unwrap_left() .map(|v| v.map_left(BasicValueEnum::into_int_value))
.into_int_value(); .map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret) generator.bool_to_i1(ctx, ret)
} }
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result. /// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
pub fn call_isnan<'ctx>( pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut dyn CodeGenerator, generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>, v: FloatValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
@ -481,20 +503,19 @@ pub fn call_isnan<'ctx>(
ctx.module.add_function("__nac3_isnan", fn_type, None) ctx.module.add_function("__nac3_isnan", fn_type, None)
}); });
let ret = ctx.builder let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isnan") .build_call(intrinsic_fn, &[v.into()], "isnan")
.try_as_basic_value() .map(CallSiteValue::try_as_basic_value)
.unwrap_left() .map(|v| v.map_left(BasicValueEnum::into_int_value))
.into_int_value(); .map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret) generator.bool_to_i1(ctx, ret)
} }
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result. /// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
pub fn call_gamma<'ctx>( pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
@ -504,16 +525,14 @@ pub fn call_gamma<'ctx>(
ctx.builder ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gamma") .build_call(intrinsic_fn, &[v.into()], "gamma")
.try_as_basic_value() .map(CallSiteValue::try_as_basic_value)
.unwrap_left() .map(|v| v.map_left(BasicValueEnum::into_float_value))
.into_float_value() .map(Either::unwrap_left)
.unwrap()
} }
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result. /// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
pub fn call_gammaln<'ctx>( pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
@ -523,16 +542,14 @@ pub fn call_gammaln<'ctx>(
ctx.builder ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gammaln") .build_call(intrinsic_fn, &[v.into()], "gammaln")
.try_as_basic_value() .map(CallSiteValue::try_as_basic_value)
.unwrap_left() .map(|v| v.map_left(BasicValueEnum::into_float_value))
.into_float_value() .map(Either::unwrap_left)
.unwrap()
} }
/// Generates a call to `j0` in IR. Returns an `f64` representing the result. /// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>( pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
@ -542,7 +559,433 @@ pub fn call_j0<'ctx>(
ctx.builder ctx.builder
.build_call(intrinsic_fn, &[v.into()], "j0") .build_call(intrinsic_fn, &[v.into()], "j0")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
let ndarray_calc_size_fn =
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
ctx.builder
.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Indices,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let index = ctx
.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
index
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let rhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
let lhs_eq_rhs = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap();
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant {
match ty.get_bit_width() {
32 => SizeVariant::Bits32,
64 => SizeVariant::Bits64,
_ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()),
}
}
fn get_size_type_dependent_function<'ctx, BuildFuncTypeFn>(
ctx: &CodeGenContext<'ctx, '_>,
size_type: IntType<'ctx>,
base_name: &str,
build_func_type: BuildFuncTypeFn,
) -> FunctionValue<'ctx>
where
BuildFuncTypeFn: Fn() -> FunctionType<'ctx>,
{
let mut fn_name = base_name.to_owned();
match get_size_variant(size_type) {
SizeVariant::Bits32 => {
// The original fn_name is the correct function name
}
SizeVariant::Bits64 => {
// Append "64" at the end, this is the naming convention for 64-bit
fn_name.push_str("64");
}
}
// Get (or declare then get if does not exist) the corresponding function
ctx.module.get_function(&fn_name).unwrap_or_else(|| {
let fn_type = build_func_type();
ctx.module.add_function(&fn_name, fn_type, None)
})
}
fn get_ndarray_struct_ptr<'ctx>(ctx: &'ctx Context, size_type: IntType<'ctx>) -> PointerType<'ctx> {
let i8_type = ctx.i8_type();
let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() };
let struct_ty = ndarray_ty.fields().whole_struct.as_struct_type(ctx);
struct_ty.ptr_type(AddressSpace::default())
}
pub fn call_nac3_ndarray_size<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NpArrayValue<'ctx>,
) -> IntValue<'ctx> {
let size_type = ndarray.ty.size_type;
let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_size", || {
size_type.fn_type(&[get_ndarray_struct_ptr(ctx.ctx, size_type).into()], false)
});
ctx.builder
.build_call(function, &[ndarray.ptr.into()], "size")
.unwrap()
.try_as_basic_value() .try_as_basic_value()
.unwrap_left() .unwrap_left()
.into_float_value() .into_int_value()
} }

View File

@ -0,0 +1,26 @@
#[cfg(test)]
mod tests {
use std::{path::Path, process::Command};
#[test]
fn run_irrt_test() {
assert!(
cfg!(feature = "test"),
"Please do `cargo test -F test` to compile `irrt_test.out` and run test"
);
let irrt_test_out_path = Path::new(concat!(env!("OUT_DIR"), "/irrt_test.out"));
let output = Command::new(irrt_test_out_path.to_str().unwrap()).output().unwrap();
if !output.status.success() {
eprintln!("irrt_test failed with status {}:", output.status);
eprintln!("====== stdout ======");
eprintln!("{}", String::from_utf8(output.stdout).unwrap());
eprintln!("====== stderr ======");
eprintln!("{}", String::from_utf8(output.stderr).unwrap());
eprintln!("====================");
panic!("irrt_test failed");
}
}
}

View File

@ -0,0 +1,308 @@
use crate::codegen::CodeGenContext;
use inkwell::context::Context;
use inkwell::intrinsics::Intrinsic;
use inkwell::types::AnyTypeEnum::IntType;
use inkwell::types::FloatType;
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
use inkwell::AddressSpace;
use itertools::Either;
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
/// functions.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
// Standard LLVM floating-point types
if ft == ctx.f16_type() {
return "f16";
}
if ft == ctx.f32_type() {
return "f32";
}
if ft == ctx.f64_type() {
return "f64";
}
if ft == ctx.f128_type() {
return "f128";
}
// Non-standard floating-point types
if ft == ctx.x86_f80_type() {
return "f80";
}
if ft == ctx.ppc_f128_type() {
return "ppcf128";
}
unreachable!()
}
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
/// intrinsic.
pub fn call_stacksave<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
name: Option<&str>,
) -> PointerValue<'ctx> {
const FN_NAME: &str = "llvm.stacksave";
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_pointer_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the
/// [`llvm.stackrestore`](https://llvm.org/docs/LangRef.html#llvm-stackrestore-intrinsic) intrinsic.
///
/// - `ptr`: The pointer storing the address to restore the stack to.
pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.stackrestore";
/*
SEE https://github.com/TheDan64/inkwell/issues/496
We want `llvm.stackrestore`, but the following would generate `llvm.stackrestore.p0i8`.
```ignore
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_p0i8.into()]))
.unwrap();
```
Temp workaround by manually declaring the intrinsic with the correct function name instead.
*/
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[ptr.into()], "").unwrap();
}
/// Invokes the [`llvm.memcpy`](https://llvm.org/docs/LangRef.html#llvm-memcpy-intrinsic) intrinsic.
///
/// * `dest` - The pointer to the destination. Must be a pointer to an integer type.
/// * `src` - The pointer to the source. Must be a pointer to an integer type.
/// * `len` - The number of bytes to copy.
/// * `is_volatile` - Whether the `memcpy` operation should be `volatile`.
pub fn call_memcpy<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
const FN_NAME: &str = "llvm.memcpy";
debug_assert!(dest.get_type().get_element_type().is_int_type());
debug_assert!(src.get_type().get_element_type().is_int_type());
debug_assert_eq!(
dest.get_type().get_element_type().into_int_type().get_bit_width(),
src.get_type().get_element_type().into_int_type().get_bit_width(),
);
debug_assert!(matches!(len.get_type().get_bit_width(), 32 | 64));
debug_assert_eq!(is_volatile.get_type().get_bit_width(), 1);
let llvm_dest_t = dest.get_type();
let llvm_src_t = src.get_type();
let llvm_len_t = len.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| {
intrinsic.get_declaration(
&ctx.module,
&[llvm_dest_t.into(), llvm_src_t.into(), llvm_len_t.into()],
)
})
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[dest.into(), src.into(), len.into(), is_volatile.into()], "")
.unwrap();
}
/// Invokes the `llvm.memcpy` intrinsic.
///
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`.
pub fn call_memcpy_generic<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let dest_elem_t = dest.get_type().get_element_type();
let src_elem_t = src.get_type().get_element_type();
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
dest
} else {
ctx.builder
.build_bitcast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
src
} else {
ctx.builder
.build_bitcast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
call_memcpy(ctx, dest, src, len, is_volatile);
}
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
///
/// Arguments:
/// * `$ctx:ident`: Reference to the current Code Generation Context
/// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>)
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
/// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type)
/// Use `BasicValueEnum::into_int_value` for Integer return type and `BasicValueEnum::into_float_value` for Float return type
/// * `$llvm_ty:ident`: Type of first operand
/// * `,($val:ident)*`: Comma separated list of operands
macro_rules! generate_llvm_intrinsic_fn_body {
($ctx:ident, $name:ident, $llvm_name:literal, $map_fn:expr, $llvm_ty:ident $(,$val:ident)*) => {{
const FN_NAME: &str = concat!("llvm.", $llvm_name);
let intrinsic_fn = Intrinsic::find(FN_NAME).and_then(|intrinsic| intrinsic.get_declaration(&$ctx.module, &[$llvm_ty.into()])).unwrap();
$ctx.builder.build_call(intrinsic_fn, &[$($val.into()),*], $name.unwrap_or_default()).map(CallSiteValue::try_as_basic_value).map(|v| v.map_left($map_fn)).map(Either::unwrap_left).unwrap()
}};
}
/// Macro to generate the llvm intrinsic function using [`generate_llvm_intrinsic_fn_body`].
///
/// Arguments:
/// * `float/int`: Indicates the return and argument type of the function
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
/// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil"
/// * `$val:ident`: The operand for unary operations
/// * `$val1:ident`, `$val2:ident`: The operands for binary operations
macro_rules! generate_llvm_intrinsic_fn {
("float", $fn_name:ident, $llvm_name:literal, $val:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_ty = $val.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_float_value, llvm_ty, $val)
}
};
("float", $fn_name:ident, $llvm_name:literal, $val1:ident, $val2:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val1: FloatValue<'ctx>,
$val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!($val1.get_type(), $val2.get_type());
let llvm_ty = $val1.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_float_value, llvm_ty, $val1, $val2)
}
};
("int", $fn_name:ident, $llvm_name:literal, $val1:ident, $val2:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val1: IntValue<'ctx>,
$val2: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!($val1.get_type().get_bit_width(), $val2.get_type().get_bit_width());
let llvm_ty = $val1.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_int_value, llvm_ty, $val1, $val2)
}
};
}
/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic.
///
/// * `src` - The value for which the absolute value is to be returned.
/// * `is_int_min_poison` - Whether `poison` is to be returned if `src` is `INT_MIN`.
pub fn call_int_abs<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: IntValue<'ctx>,
is_int_min_poison: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(is_int_min_poison.get_type().get_bit_width(), 1);
debug_assert!(is_int_min_poison.is_const());
let src_type = src.get_type();
generate_llvm_intrinsic_fn_body!(
ctx,
name,
"abs",
BasicValueEnum::into_int_value,
src_type,
src,
is_int_min_poison
)
}
generate_llvm_intrinsic_fn!("int", call_int_smax, "smax", a, b);
generate_llvm_intrinsic_fn!("int", call_int_smin, "smin", a, b);
generate_llvm_intrinsic_fn!("int", call_int_umax, "umax", a, b);
generate_llvm_intrinsic_fn!("int", call_int_umin, "umin", a, b);
generate_llvm_intrinsic_fn!("int", call_expect, "expect", val, expected_val);
generate_llvm_intrinsic_fn!("float", call_float_sqrt, "sqrt", val);
generate_llvm_intrinsic_fn!("float", call_float_sin, "sin", val);
generate_llvm_intrinsic_fn!("float", call_float_cos, "cos", val);
generate_llvm_intrinsic_fn!("float", call_float_pow, "pow", val, power);
generate_llvm_intrinsic_fn!("float", call_float_exp, "exp", val);
generate_llvm_intrinsic_fn!("float", call_float_exp2, "exp2", val);
generate_llvm_intrinsic_fn!("float", call_float_log, "log", val);
generate_llvm_intrinsic_fn!("float", call_float_log10, "log10", val);
generate_llvm_intrinsic_fn!("float", call_float_log2, "log2", val);
generate_llvm_intrinsic_fn!("float", call_float_fabs, "fabs", src);
generate_llvm_intrinsic_fn!("float", call_float_minnum, "minnum", val, power);
generate_llvm_intrinsic_fn!("float", call_float_maxnum, "maxnum", val, power);
generate_llvm_intrinsic_fn!("float", call_float_copysign, "copysign", mag, sgn);
generate_llvm_intrinsic_fn!("float", call_float_floor, "floor", val);
generate_llvm_intrinsic_fn!("float", call_float_ceil, "ceil", val);
generate_llvm_intrinsic_fn!("float", call_float_round, "round", val);
generate_llvm_intrinsic_fn!("float", call_float_rint, "rint", val);
/// Invokes the [`llvm.powi`](https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic) intrinsic.
pub fn call_float_powi<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
power: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.powi";
let llvm_val_t = val.get_type();
let llvm_power_t = power.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| {
intrinsic.get_declaration(&ctx.module, &[llvm_val_t.into(), llvm_power_t.into()])
})
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

View File

@ -1,6 +1,7 @@
use crate::{ use crate::{
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
symbol_resolver::{StaticValue, SymbolResolver}, symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{TopLevelContext, TopLevelDef}, toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore}, type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
@ -8,24 +9,22 @@ use crate::{
}; };
use crossbeam::channel::{unbounded, Receiver, Sender}; use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{ use inkwell::{
AddressSpace,
IntPredicate,
OptimizationLevel,
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock, basic_block::BasicBlock,
builder::Builder, builder::Builder,
context::Context, context::Context,
debug_info::{
AsDIScope, DICompileUnit, DIFlagsConstants, DIScope, DISubprogram, DebugInfoBuilder,
},
module::Module, module::Module,
passes::PassBuilderOptions, passes::PassBuilderOptions,
targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple}, targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple},
types::{AnyType, BasicType, BasicTypeEnum}, types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue}, values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
debug_info::{ AddressSpace, IntPredicate, OptimizationLevel,
DebugInfoBuilder, DICompileUnit, DISubprogram, AsDIScope, DIFlagsConstants, DIScope
},
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3parser::ast::{Stmt, StrRef, Location}; use nac3parser::ast::{Location, Stmt, StrRef};
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::{ use std::sync::{
@ -34,10 +33,15 @@ use std::sync::{
}; };
use std::thread; use std::thread;
pub mod builtin_fns;
pub mod classes;
pub mod concrete_type; pub mod concrete_type;
pub mod expr; pub mod expr;
pub mod extern_fns;
mod generator; mod generator;
pub mod irrt; pub mod irrt;
pub mod llvm_intrinsics;
pub mod numpy;
pub mod stmt; pub mod stmt;
#[cfg(test)] #[cfg(test)]
@ -80,7 +84,6 @@ pub struct CodeGenTargetMachineOptions {
} }
impl CodeGenTargetMachineOptions { impl CodeGenTargetMachineOptions {
/// Creates an instance of [`CodeGenTargetMachineOptions`] using the triple of the host machine. /// Creates an instance of [`CodeGenTargetMachineOptions`] using the triple of the host machine.
/// Other options are set to defaults. /// Other options are set to defaults.
#[must_use] #[must_use]
@ -109,13 +112,11 @@ impl CodeGenTargetMachineOptions {
/// ///
/// See [`Target::create_target_machine`]. /// See [`Target::create_target_machine`].
#[must_use] #[must_use]
pub fn create_target_machine( pub fn create_target_machine(&self, level: OptimizationLevel) -> Option<TargetMachine> {
&self,
level: OptimizationLevel,
) -> Option<TargetMachine> {
let triple = TargetTriple::create(self.triple.as_str()); let triple = TargetTriple::create(self.triple.as_str());
let target = Target::from_triple(&triple) let target = Target::from_triple(&triple).unwrap_or_else(|_| {
.unwrap_or_else(|_| panic!("could not create target from target triple {}", self.triple)); panic!("could not create target from target triple {}", self.triple)
});
target.create_target_machine( target.create_target_machine(
&triple, &triple,
@ -123,7 +124,7 @@ impl CodeGenTargetMachineOptions {
self.features.as_str(), self.features.as_str(),
level, level,
self.reloc_mode, self.reloc_mode,
self.code_model self.code_model,
) )
} }
} }
@ -134,24 +135,23 @@ pub struct CodeGenContext<'ctx, 'a> {
/// The [Builder] instance for creating LLVM IR statements. /// The [Builder] instance for creating LLVM IR statements.
pub builder: Builder<'ctx>, pub builder: Builder<'ctx>,
/// The [DebugInfoBuilder], [compilation unit information][DICompileUnit], and /// The [`DebugInfoBuilder`], [compilation unit information][DICompileUnit], and
/// [scope information][DIScope] of this context. /// [scope information][DIScope] of this context.
pub debug_info: (DebugInfoBuilder<'ctx>, DICompileUnit<'ctx>, DIScope<'ctx>), pub debug_info: (DebugInfoBuilder<'ctx>, DICompileUnit<'ctx>, DIScope<'ctx>),
/// The module for which [this context][CodeGenContext] is generating into. /// The module for which [this context][CodeGenContext] is generating into.
pub module: Module<'ctx>, pub module: Module<'ctx>,
/// The [TopLevelContext] associated with [this context][CodeGenContext]. /// The [`TopLevelContext`] associated with [this context][CodeGenContext].
pub top_level: &'a TopLevelContext, pub top_level: &'a TopLevelContext,
pub unifier: Unifier, pub unifier: Unifier,
pub resolver: Arc<dyn SymbolResolver + Send + Sync>, pub resolver: Arc<dyn SymbolResolver + Send + Sync>,
pub static_value_store: Arc<Mutex<StaticValueStore>>, pub static_value_store: Arc<Mutex<StaticValueStore>>,
/// A [HashMap] containing the mapping between the names of variables currently in-scope and /// A [`HashMap`] containing the mapping between the names of variables currently in-scope and
/// its value information. /// its value information.
pub var_assignment: HashMap<StrRef, VarValue<'ctx>>, pub var_assignment: HashMap<StrRef, VarValue<'ctx>>,
///
pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>, pub type_cache: HashMap<Type, BasicTypeEnum<'ctx>>,
pub primitives: PrimitiveStore, pub primitives: PrimitiveStore,
pub calls: Arc<HashMap<CodeLocation, CallId>>, pub calls: Arc<HashMap<CodeLocation, CallId>>,
@ -160,24 +160,24 @@ pub struct CodeGenContext<'ctx, 'a> {
/// Cache for constant strings. /// Cache for constant strings.
pub const_strings: HashMap<String, BasicValueEnum<'ctx>>, pub const_strings: HashMap<String, BasicValueEnum<'ctx>>,
/// [BasicBlock] containing all `alloca` statements for the current function. /// [`BasicBlock`] containing all `alloca` statements for the current function.
pub init_bb: BasicBlock<'ctx>, pub init_bb: BasicBlock<'ctx>,
pub exception_val: Option<PointerValue<'ctx>>, pub exception_val: Option<PointerValue<'ctx>>,
/// The header and exit basic blocks of a loop in this context. See /// The header and exit basic blocks of a loop in this context. See
/// https://llvm.org/docs/LoopTerminology.html for explanation of these terminology. /// <https://llvm.org/docs/LoopTerminology.html> for explanation of these terminology.
pub loop_target: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>, pub loop_target: Option<(BasicBlock<'ctx>, BasicBlock<'ctx>)>,
/// The target [BasicBlock] to jump to when performing stack unwind. /// The target [`BasicBlock`] to jump to when performing stack unwind.
pub unwind_target: Option<BasicBlock<'ctx>>, pub unwind_target: Option<BasicBlock<'ctx>>,
/// The target [BasicBlock] to jump to before returning from the function. /// The target [`BasicBlock`] to jump to before returning from the function.
/// ///
/// If this field is [None] when generating a return from a function, `ret` with no argument can /// If this field is [None] when generating a return from a function, `ret` with no argument can
/// be emitted. /// be emitted.
pub return_target: Option<BasicBlock<'ctx>>, pub return_target: Option<BasicBlock<'ctx>>,
/// The [PointerValue] containing the return value of the function. /// The [`PointerValue`] containing the return value of the function.
pub return_buffer: Option<PointerValue<'ctx>>, pub return_buffer: Option<PointerValue<'ctx>>,
// outer catch clauses // outer catch clauses
@ -186,7 +186,7 @@ pub struct CodeGenContext<'ctx, 'a> {
/// Whether `sret` is needed for the first parameter of the function. /// Whether `sret` is needed for the first parameter of the function.
/// ///
/// See [need_sret]. /// See [`need_sret`].
pub need_sret: bool, pub need_sret: bool,
/// The current source location. /// The current source location.
@ -194,7 +194,6 @@ pub struct CodeGenContext<'ctx, 'a> {
} }
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
/// Whether the [current basic block][Builder::get_insert_block] referenced by `builder` /// Whether the [current basic block][Builder::get_insert_block] referenced by `builder`
/// contains a [terminator statement][BasicBlock::get_terminator]. /// contains a [terminator statement][BasicBlock::get_terminator].
pub fn is_terminated(&self) -> bool { pub fn is_terminated(&self) -> bool {
@ -236,11 +235,10 @@ pub struct WorkerRegistry {
static_value_store: Arc<Mutex<StaticValueStore>>, static_value_store: Arc<Mutex<StaticValueStore>>,
/// LLVM-related options for code generation. /// LLVM-related options for code generation.
llvm_options: CodeGenLLVMOptions, pub llvm_options: CodeGenLLVMOptions,
} }
impl WorkerRegistry { impl WorkerRegistry {
/// Creates workers for this registry. /// Creates workers for this registry.
#[must_use] #[must_use]
pub fn create_workers<G: CodeGenerator + Send + 'static>( pub fn create_workers<G: CodeGenerator + Send + 'static>(
@ -275,9 +273,15 @@ impl WorkerRegistry {
let registry = registry.clone(); let registry = registry.clone();
let registry2 = registry.clone(); let registry2 = registry.clone();
let f = f.clone(); let f = f.clone();
let handle = thread::spawn(move || {
let worker_thread_name =
format!("codegen-worker-{worker_id}", worker_id = generator.get_name());
let handle = thread::Builder::new()
.name(worker_thread_name)
.spawn(move || {
registry.worker_thread(generator.as_mut(), &f); registry.worker_thread(generator.as_mut(), &f);
}); })
.unwrap();
let handle = thread::spawn(move || { let handle = thread::spawn(move || {
if let Err(e) = handle.join() { if let Err(e) = handle.join() {
if let Some(e) = e.downcast_ref::<&'static str>() { if let Some(e) = e.downcast_ref::<&'static str>() {
@ -362,7 +366,11 @@ impl WorkerRegistry {
*self.task_count.lock() -= 1; *self.task_count.lock() -= 1;
self.wait_condvar.notify_all(); self.wait_condvar.notify_all();
} }
assert!(errors.is_empty(), "Codegen error: {}", errors.into_iter().sorted().join("\n----------\n")); assert!(
errors.is_empty(),
"Codegen error: {}",
errors.into_iter().sorted().join("\n----------\n")
);
let result = module.verify(); let result = module.verify();
if let Err(err) = result { if let Err(err) = result {
@ -375,13 +383,20 @@ impl WorkerRegistry {
.llvm_options .llvm_options
.target .target
.create_target_machine(self.llvm_options.opt_level) .create_target_machine(self.llvm_options.opt_level)
.unwrap_or_else(|| panic!("could not create target machine from properties {:?}", self.llvm_options.target)); .unwrap_or_else(|| {
panic!(
"could not create target machine from properties {:?}",
self.llvm_options.target
)
});
let passes = format!("default<O{}>", self.llvm_options.opt_level as u32); let passes = format!("default<O{}>", self.llvm_options.opt_level as u32);
let result = module.run_passes(passes.as_str(), &target_machine, pass_options); let result = module.run_passes(passes.as_str(), &target_machine, pass_options);
if let Err(err) = result { if let Err(err) = result {
panic!("Failed to run optimization for module `{}`: {}", panic!(
"Failed to run optimization for module `{}`: {}",
module.get_name().to_str().unwrap(), module.get_name().to_str().unwrap(),
err.to_string()); err.to_string()
);
} }
f.run(&module); f.run(&module);
@ -407,14 +422,14 @@ pub struct CodeGenTask {
/// ///
/// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable /// This function is used to obtain the in-memory representation of `ty`, e.g. a `bool` variable
/// would be represented by an `i8`. /// would be represented by an `i8`.
fn get_llvm_type<'ctx>( #[allow(clippy::too_many_arguments)]
fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>, module: &Module<'ctx>,
generator: &mut dyn CodeGenerator, generator: &mut G,
unifier: &mut Unifier, unifier: &mut Unifier,
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
primitives: &PrimitiveStore,
ty: Type, ty: Type,
) -> BasicTypeEnum<'ctx> { ) -> BasicTypeEnum<'ctx> {
use TypeEnum::*; use TypeEnum::*;
@ -424,29 +439,51 @@ fn get_llvm_type<'ctx>(
let ty_enum = unifier.get_ty(ty); let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum { let result = match &*ty_enum {
TObj { obj_id, fields, .. } => { TObj { obj_id, fields, .. } => {
// check to avoid treating primitives other than Option as classes // check to avoid treating non-class primitives as classes
if obj_id.0 <= 10 { if PrimDef::contains_id(*obj_id) {
match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref()) return match &*unifier.get_ty_immutable(ty) {
{ TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => {
( get_llvm_type(
TObj { obj_id, params, .. },
TObj { obj_id: opt_id, .. },
) if *obj_id == *opt_id => {
return get_llvm_type(
ctx, ctx,
module, module,
generator, generator,
unifier, unifier,
top_level, top_level,
type_cache, type_cache,
primitives,
*params.iter().next().unwrap().1, *params.iter().next().unwrap().1,
) )
.ptr_type(AddressSpace::default()) .ptr_type(AddressSpace::default())
.into(); .into()
} }
_ => unreachable!("must be option type"),
TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let element_type = get_llvm_type(
ctx,
module,
generator,
unifier,
top_level,
type_cache,
*params.iter().next().unwrap().1,
);
ListType::new(generator, ctx, element_type).as_base_type().into()
} }
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, dtype,
);
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
}
_ => unreachable!(
"LLVM type for primitive {} is missing",
unifier.stringify(ty)
),
};
} }
// a struct with fields in the order of declaration // a struct with fields in the order of declaration
let top_level_defs = top_level.definitions.read(); let top_level_defs = top_level.definitions.read();
@ -462,7 +499,7 @@ fn get_llvm_type<'ctx>(
let struct_type = ctx.opaque_struct_type(&name); let struct_type = ctx.opaque_struct_type(&name);
type_cache.insert( type_cache.insert(
unifier.get_representative(ty), unifier.get_representative(ty),
struct_type.ptr_type(AddressSpace::default()).into() struct_type.ptr_type(AddressSpace::default()).into(),
); );
let fields = fields_list let fields = fields_list
.iter() .iter()
@ -474,7 +511,6 @@ fn get_llvm_type<'ctx>(
unifier, unifier,
top_level, top_level,
type_cache, type_cache,
primitives,
fields[&f.0].0, fields[&f.0].0,
) )
}) })
@ -482,31 +518,18 @@ fn get_llvm_type<'ctx>(
struct_type.set_body(&fields, false); struct_type.set_body(&fields, false);
struct_type.ptr_type(AddressSpace::default()).into() struct_type.ptr_type(AddressSpace::default()).into()
}; };
return ty return ty;
} }
TTuple { ty } => { TTuple { ty } => {
// a struct with fields in the order present in the tuple // a struct with fields in the order present in the tuple
let fields = ty let fields = ty
.iter() .iter()
.map(|ty| { .map(|ty| {
get_llvm_type( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
)
}) })
.collect_vec(); .collect_vec();
ctx.struct_type(&fields, false).into() ctx.struct_type(&fields, false).into()
} }
TList { ty } => {
// a struct with an integer and a pointer to an array
let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
);
let fields = [
element_type.ptr_type(AddressSpace::default()).into(),
generator.get_size_type(ctx).into(),
];
ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into()
}
TVirtual { .. } => unimplemented!(), TVirtual { .. } => unimplemented!(),
_ => unreachable!("{}", ty_enum.get_type_name()), _ => unreachable!("{}", ty_enum.get_type_name()),
}; };
@ -524,10 +547,11 @@ fn get_llvm_type<'ctx>(
/// ABI representation is that the in-memory representation must be at least byte-sized and must /// ABI representation is that the in-memory representation must be at least byte-sized and must
/// be byte-aligned for the variable to be addressable in memory, whereas there is no such /// be byte-aligned for the variable to be addressable in memory, whereas there is no such
/// restriction for ABI representations. /// restriction for ABI representations.
fn get_llvm_abi_type<'ctx>( #[allow(clippy::too_many_arguments)]
fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>, module: &Module<'ctx>,
generator: &mut dyn CodeGenerator, generator: &mut G,
unifier: &mut Unifier, unifier: &mut Unifier,
top_level: &TopLevelContext, top_level: &TopLevelContext,
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>, type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
@ -539,8 +563,8 @@ fn get_llvm_abi_type<'ctx>(
return if unifier.unioned(ty, primitives.bool) { return if unifier.unioned(ty, primitives.bool) {
ctx.bool_type().into() ctx.bool_type().into()
} else { } else {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, primitives, ty) get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
} };
} }
/// Whether `sret` is needed for a return value with type `ty`. /// Whether `sret` is needed for a return value with type `ty`.
@ -556,8 +580,9 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
match ty { match ty {
BasicTypeEnum::IntType(_) | BasicTypeEnum::PointerType(_) => false, BasicTypeEnum::IntType(_) | BasicTypeEnum::PointerType(_) => false,
BasicTypeEnum::FloatType(_) if maybe_large => false, BasicTypeEnum::FloatType(_) if maybe_large => false,
BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 => BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 => {
ty.get_field_types().iter().any(|ty| need_sret_impl(*ty, false)), ty.get_field_types().iter().any(|ty| need_sret_impl(*ty, false))
}
_ => true, _ => true,
} }
} }
@ -565,14 +590,18 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
} }
/// Implementation for generating LLVM IR for a function. /// Implementation for generating LLVM IR for a function.
pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>> ( pub fn gen_func_impl<
'ctx,
G: CodeGenerator,
F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>,
>(
context: &'ctx Context, context: &'ctx Context,
generator: &mut G, generator: &mut G,
registry: &WorkerRegistry, registry: &WorkerRegistry,
builder: Builder<'ctx>, builder: Builder<'ctx>,
module: Module<'ctx>, module: Module<'ctx>,
task: CodeGenTask, task: CodeGenTask,
codegen_function: F codegen_function: F,
) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> { ) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> {
let top_level_ctx = registry.top_level_ctx.clone(); let top_level_ctx = registry.top_level_ctx.clone();
let static_value_store = registry.static_value_store.clone(); let static_value_store = registry.static_value_store.clone();
@ -580,6 +609,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let (unifier, primitives) = &top_level_ctx.unifiers.read()[task.unifier_index]; let (unifier, primitives) = &top_level_ctx.unifiers.read()[task.unifier_index];
(Unifier::from_shared_unifier(unifier), *primitives) (Unifier::from_shared_unifier(unifier), *primitives)
}; };
unifier.put_primitive_store(&primitives);
unifier.top_level = Some(top_level_ctx.clone()); unifier.top_level = Some(top_level_ctx.clone());
let mut cache = HashMap::new(); let mut cache = HashMap::new();
@ -613,6 +643,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
str: unifier.get_representative(primitives.str), str: unifier.get_representative(primitives.str),
exception: unifier.get_representative(primitives.exception), exception: unifier.get_representative(primitives.exception),
option: unifier.get_representative(primitives.option), option: unifier.get_representative(primitives.option),
..primitives
}; };
let mut type_cache: HashMap<_, _> = [ let mut type_cache: HashMap<_, _> = [
@ -634,10 +665,10 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
str_type.set_body(&fields, false); str_type.set_body(&fields, false);
str_type.into() str_type.into()
} }
Some(t) => t.as_basic_type_enum() Some(t) => t.as_basic_type_enum(),
} }
}), }),
(primitives.range, context.i32_type().array_type(3).ptr_type(AddressSpace::default()).into()), (primitives.range, RangeType::new(context).as_base_type().into()),
(primitives.exception, { (primitives.exception, {
let name = "Exception"; let name = "Exception";
if let Some(t) = module.get_struct_type(name) { if let Some(t) = module.get_struct_type(name) {
@ -651,7 +682,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
exception.set_body(&fields, false); exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::default()).as_basic_type_enum() exception.ptr_type(AddressSpace::default()).as_basic_type_enum()
} }
}) }),
] ]
.iter() .iter()
.copied() .copied()
@ -659,8 +690,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
// NOTE: special handling of option cannot use this type cache since it contains type var, // NOTE: special handling of option cannot use this type cache since it contains type var,
// handled inside get_llvm_type instead // handled inside get_llvm_type instead
let ConcreteTypeEnum::TFunc { args, ret, .. } = let ConcreteTypeEnum::TFunc { args, ret, .. } = task.store.get(task.signature) else {
task.store.get(task.signature) else {
unreachable!() unreachable!()
}; };
@ -677,7 +707,16 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let ret_type = if unifier.unioned(ret, primitives.none) { let ret_type = if unifier.unioned(ret, primitives.none) {
None None
} else { } else {
Some(get_llvm_abi_type(context, &module, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, &primitives, ret)) Some(get_llvm_abi_type(
context,
&module,
generator,
&mut unifier,
top_level_ctx.as_ref(),
&mut type_cache,
&primitives,
ret,
))
}; };
let has_sret = ret_type.map_or(false, |ty| need_sret(ty)); let has_sret = ret_type.map_or(false, |ty| need_sret(ty));
@ -704,7 +743,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let fn_type = match ret_type { let fn_type = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, false), Some(ret_type) if !has_sret => ret_type.fn_type(&params, false),
_ => context.void_type().fn_type(&params, false) _ => context.void_type().fn_type(&params, false),
}; };
let symbol = &task.symbol_name; let symbol = &task.symbol_name;
@ -719,9 +758,13 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
fn_val.set_personality_function(personality); fn_val.set_personality_function(personality);
} }
if has_sret { if has_sret {
fn_val.add_attribute(AttributeLoc::Param(0), fn_val.add_attribute(
context.create_type_attribute(Attribute::get_named_enum_kind_id("sret"), AttributeLoc::Param(0),
ret_type.unwrap().as_any_type_enum())); context.create_type_attribute(
Attribute::get_named_enum_kind_id("sret"),
ret_type.unwrap().as_any_type_enum(),
),
);
} }
let init_bb = context.append_basic_block(fn_val, "init"); let init_bb = context.append_basic_block(fn_val, "init");
@ -739,13 +782,10 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
&mut unifier, &mut unifier,
top_level_ctx.as_ref(), top_level_ctx.as_ref(),
&mut type_cache, &mut type_cache,
&primitives,
arg.ty, arg.ty,
); );
let alloca = builder.build_alloca( let alloca =
local_type, builder.build_alloca(local_type, &format!("{}.addr", &arg.name.to_string())).unwrap();
&format!("{}.addr", &arg.name.to_string()),
);
// Remap boolean parameters into i8 // Remap boolean parameters into i8
let param = if local_type.is_int_type() && param.is_int_value() { let param = if local_type.is_int_type() && param.is_int_value() {
@ -756,19 +796,20 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
bool_to_i8(&builder, context, param_val) bool_to_i8(&builder, context, param_val)
} else { } else {
param_val param_val
}.into() }
.into()
} else { } else {
param param
}; };
builder.build_store(alloca, param); builder.build_store(alloca, param).unwrap();
var_assignment.insert(arg.name, (alloca, None, 0)); var_assignment.insert(arg.name, (alloca, None, 0));
} }
let return_buffer = if has_sret { let return_buffer = if has_sret {
Some(fn_val.get_nth_param(0).unwrap().into_pointer_value()) Some(fn_val.get_nth_param(0).unwrap().into_pointer_value())
} else { } else {
fn_type.get_return_type().map(|v| builder.build_alloca(v, "$ret")) fn_type.get_return_type().map(|v| builder.build_alloca(v, "$ret").unwrap())
}; };
let static_values = { let static_values = {
@ -780,7 +821,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
*static_val = Some(v); *static_val = Some(v);
} }
builder.build_unconditional_branch(body_bb); builder.build_unconditional_branch(body_bb).unwrap();
builder.position_at_end(body_bb); builder.position_at_end(body_bb);
let (dibuilder, compile_unit) = module.create_debug_info_builder( let (dibuilder, compile_unit) = module.create_debug_info_builder(
@ -789,11 +830,8 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
/* filename */ /* filename */
&task &task
.body .body
.get(0) .first()
.map_or_else( .map_or_else(|| "<nac3_internal>".to_string(), |f| f.location.file.0.to_string()),
|| "<nac3_internal>".to_string(),
|f| f.location.file.0.to_string(),
),
/* directory */ "", /* directory */ "",
/* producer */ "NAC3", /* producer */ "NAC3",
/* is_optimized */ registry.llvm_options.opt_level != OptimizationLevel::None, /* is_optimized */ registry.llvm_options.opt_level != OptimizationLevel::None,
@ -819,7 +857,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
inkwell::debug_info::DIFlags::PUBLIC, inkwell::debug_info::DIFlags::PUBLIC,
); );
let (row, col) = let (row, col) =
task.body.get(0).map_or_else(|| (0, 0), |b| (b.location.row, b.location.column)); task.body.first().map_or_else(|| (0, 0), |b| (b.location.row, b.location.column));
let func_scope: DISubprogram<'_> = dibuilder.create_function( let func_scope: DISubprogram<'_> = dibuilder.create_function(
/* scope */ compile_unit.as_debug_info_scope(), /* scope */ compile_unit.as_debug_info_scope(),
/* func name */ symbol, /* func name */ symbol,
@ -866,7 +904,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
row as u32, row as u32,
col as u32, col as u32,
func_scope.as_debug_info_scope(), func_scope.as_debug_info_scope(),
None None,
); );
code_gen_context.builder.set_current_debug_location(loc); code_gen_context.builder.set_current_debug_location(loc);
@ -874,7 +912,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
// after static analysis, only void functions can have no return at the end. // after static analysis, only void functions can have no return at the end.
if !code_gen_context.is_terminated() { if !code_gen_context.is_terminated() {
code_gen_context.builder.build_return(None); code_gen_context.builder.build_return(None).unwrap();
} }
code_gen_context.builder.unset_current_debug_location(); code_gen_context.builder.unset_current_debug_location();
@ -916,12 +954,14 @@ fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntV
if bool_value.get_type().get_bit_width() == 1 { if bool_value.get_type().get_bit_width() == 1 {
bool_value bool_value
} else { } else {
builder.build_int_compare( builder
.build_int_compare(
IntPredicate::NE, IntPredicate::NE,
bool_value, bool_value,
bool_value.get_type().const_zero(), bool_value.get_type().const_zero(),
"tobool" "tobool",
) )
.unwrap()
} }
} }
@ -929,21 +969,23 @@ fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntV
fn bool_to_i8<'ctx>( fn bool_to_i8<'ctx>(
builder: &Builder<'ctx>, builder: &Builder<'ctx>,
ctx: &'ctx Context, ctx: &'ctx Context,
bool_value: IntValue<'ctx> bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let value_bits = bool_value.get_type().get_bit_width(); let value_bits = bool_value.get_type().get_bit_width();
match value_bits { match value_bits {
8 => bool_value, 8 => bool_value,
1 => builder.build_int_z_extend(bool_value, ctx.i8_type(), "frombool"), 1 => builder.build_int_z_extend(bool_value, ctx.i8_type(), "frombool").unwrap(),
_ => bool_to_i8( _ => bool_to_i8(
builder, builder,
ctx, ctx,
builder.build_int_compare( builder
.build_int_compare(
IntPredicate::NE, IntPredicate::NE,
bool_value, bool_value,
bool_value.get_type().const_zero(), bool_value.get_type().const_zero(),
"" "",
) )
.unwrap(),
), ),
} }
} }
@ -969,9 +1011,20 @@ fn gen_in_range_check<'ctx>(
stop: IntValue<'ctx>, stop: IntValue<'ctx>,
step: IntValue<'ctx>, step: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), ""); let sign = ctx
let lo = ctx.builder.build_select(sign, value, stop, "").into_int_value(); .builder
let hi = ctx.builder.build_select(sign, stop, value, "").into_int_value(); .build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "")
.unwrap();
let lo = ctx
.builder
.build_select(sign, value, stop, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let hi = ctx
.builder
.build_select(sign, stop, value, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp").unwrap()
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,22 +1,27 @@
use crate::{ use crate::{
codegen::{ codegen::{
concrete_type::ConcreteTypeStore, CodeGenContext, CodeGenLLVMOptions, classes::{ListType, NDArrayType, ProxyType, RangeType},
CodeGenTargetMachineOptions, CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry, concrete_type::ConcreteTypeStore,
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask,
CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry,
}, },
symbol_resolver::{SymbolResolver, ValueEnum}, symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{ toplevel::{
composer::TopLevelComposer, DefinitionId, FunInstance, TopLevelContext, TopLevelDef, composer::{ComposerConfig, TopLevelComposer},
DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
}, },
typecheck::{ typecheck::{
type_inferencer::{FunctionData, Inferencer, PrimitiveStore}, type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
}, },
}; };
use indexmap::IndexMap;
use indoc::indoc; use indoc::indoc;
use inkwell::{ use inkwell::{
targets::{InitializationConfig, Target}, targets::{InitializationConfig, Target},
OptimizationLevel OptimizationLevel,
}; };
use nac3parser::ast::FileName;
use nac3parser::{ use nac3parser::{
ast::{fold::Fold, StrRef}, ast::{fold::Fold, StrRef},
parser::parse_program, parser::parse_program,
@ -52,13 +57,13 @@ impl SymbolResolver for Resolver {
_: &PrimitiveStore, _: &PrimitiveStore,
str: StrRef, str: StrRef,
) -> Result<Type, String> { ) -> Result<Type, String> {
self.id_to_type.get(&str).cloned().ok_or_else(|| format!("cannot find symbol `{}`", str)) self.id_to_type.get(&str).copied().ok_or_else(|| format!("cannot find symbol `{str}`"))
} }
fn get_symbol_value<'ctx, 'a>( fn get_symbol_value<'ctx>(
&self, &self,
_: StrRef, _: StrRef,
_: &mut CodeGenContext<'ctx, 'a>, _: &mut CodeGenContext<'ctx, '_>,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
unimplemented!() unimplemented!()
} }
@ -67,10 +72,8 @@ impl SymbolResolver for Resolver {
self.id_to_def self.id_to_def
.read() .read()
.get(&id) .get(&id)
.cloned() .copied()
.ok_or_else(|| HashSet::from([ .ok_or_else(|| HashSet::from([format!("cannot find symbol `{id}`")]))
format!("cannot find symbol `{}`", id),
]))
} }
fn get_string_id(&self, _: &str) -> i32 { fn get_string_id(&self, _: &str) -> i32 {
@ -89,9 +92,9 @@ fn test_primitives() {
d = a if c == 1 else 0 d = a if c == 1 else 0
return d return d
"}; "};
let statements = parse_program(source, Default::default()).unwrap(); let statements = parse_program(source, FileName::default()).unwrap();
let composer: TopLevelComposer = Default::default(); let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
@ -100,7 +103,7 @@ fn test_primitives() {
let resolver = Arc::new(Resolver { let resolver = Arc::new(Resolver {
id_to_type: HashMap::new(), id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()), id_to_def: RwLock::new(HashMap::new()),
class_names: Default::default(), class_names: HashMap::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>; }) as Arc<dyn SymbolResolver + Send + Sync>;
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
@ -110,7 +113,7 @@ fn test_primitives() {
FuncArg { name: "b".into(), ty: primitives.int32, default_value: None }, FuncArg { name: "b".into(), ty: primitives.int32, default_value: None },
], ],
ret: primitives.int32, ret: primitives.int32,
vars: HashMap::new(), vars: VarMap::new(),
}; };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
@ -125,12 +128,12 @@ fn test_primitives() {
}; };
let mut virtual_checks = Vec::new(); let mut virtual_checks = Vec::new();
let mut calls = HashMap::new(); let mut calls = HashMap::new();
let mut identifiers: HashSet<_> = ["a".into(), "b".into()].iter().cloned().collect(); let mut identifiers: HashSet<_> = ["a".into(), "b".into()].into();
let mut inferencer = Inferencer { let mut inferencer = Inferencer {
top_level: &top_level, top_level: &top_level,
function_data: &mut function_data, function_data: &mut function_data,
unifier: &mut unifier, unifier: &mut unifier,
variable_mapping: Default::default(), variable_mapping: HashMap::default(),
primitives: &primitives, primitives: &primitives,
virtual_checks: &mut virtual_checks, virtual_checks: &mut virtual_checks,
calls: &mut calls, calls: &mut calls,
@ -154,7 +157,7 @@ fn test_primitives() {
}); });
let task = CodeGenTask { let task = CodeGenTask {
subst: Default::default(), subst: Vec::default(),
symbol_name: "testing".into(), symbol_name: "testing".into(),
body: Arc::new(statements), body: Arc::new(statements),
unifier_index: 0, unifier_index: 0,
@ -225,12 +228,7 @@ fn test_primitives() {
opt_level: OptimizationLevel::Default, opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(), target: CodeGenTargetMachineOptions::from_host_triple(),
}; };
let (registry, handles) = WorkerRegistry::create_workers( let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
threads,
top_level,
&llvm_options,
&f
);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
} }
@ -241,14 +239,14 @@ fn test_simple_call() {
a = foo(a) a = foo(a)
return a * 2 return a * 2
"}; "};
let statements_1 = parse_program(source_1, Default::default()).unwrap(); let statements_1 = parse_program(source_1, FileName::default()).unwrap();
let source_2 = indoc! { " let source_2 = indoc! { "
return a + 1 return a + 1
"}; "};
let statements_2 = parse_program(source_2, Default::default()).unwrap(); let statements_2 = parse_program(source_2, FileName::default()).unwrap();
let composer: TopLevelComposer = Default::default(); let composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
@ -257,7 +255,7 @@ fn test_simple_call() {
let signature = FunSignature { let signature = FunSignature {
args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }], args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }],
ret: primitives.int32, ret: primitives.int32,
vars: HashMap::new(), vars: VarMap::new(),
}; };
let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone())); let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone()));
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
@ -281,7 +279,7 @@ fn test_simple_call() {
let resolver = Resolver { let resolver = Resolver {
id_to_type: HashMap::new(), id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()), id_to_def: RwLock::new(HashMap::new()),
class_names: Default::default(), class_names: HashMap::default(),
}; };
resolver.add_id_def("foo".into(), DefinitionId(foo_id)); resolver.add_id_def("foo".into(), DefinitionId(foo_id));
let resolver = Arc::new(resolver) as Arc<dyn SymbolResolver + Send + Sync>; let resolver = Arc::new(resolver) as Arc<dyn SymbolResolver + Send + Sync>;
@ -302,12 +300,12 @@ fn test_simple_call() {
}; };
let mut virtual_checks = Vec::new(); let mut virtual_checks = Vec::new();
let mut calls = HashMap::new(); let mut calls = HashMap::new();
let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].iter().cloned().collect(); let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].into();
let mut inferencer = Inferencer { let mut inferencer = Inferencer {
top_level: &top_level, top_level: &top_level,
function_data: &mut function_data, function_data: &mut function_data,
unifier: &mut unifier, unifier: &mut unifier,
variable_mapping: Default::default(), variable_mapping: HashMap::default(),
primitives: &primitives, primitives: &primitives,
virtual_checks: &mut virtual_checks, virtual_checks: &mut virtual_checks,
calls: &mut calls, calls: &mut calls,
@ -336,11 +334,11 @@ fn test_simple_call() {
&mut *top_level.definitions.read()[foo_id].write() &mut *top_level.definitions.read()[foo_id].write()
{ {
instance_to_stmt.insert( instance_to_stmt.insert(
"".to_string(), String::new(),
FunInstance { FunInstance {
body: Arc::new(statements_2), body: Arc::new(statements_2),
calls: Arc::new(inferencer.calls.clone()), calls: Arc::new(inferencer.calls.clone()),
subst: Default::default(), subst: IndexMap::default(),
unifier_id: 0, unifier_id: 0,
}, },
); );
@ -356,7 +354,7 @@ fn test_simple_call() {
}); });
let task = CodeGenTask { let task = CodeGenTask {
subst: Default::default(), subst: Vec::default(),
symbol_name: "testing".to_string(), symbol_name: "testing".to_string(),
body: Arc::new(statements_1), body: Arc::new(statements_1),
calls: Arc::new(calls1), calls: Arc::new(calls1),
@ -415,12 +413,39 @@ fn test_simple_call() {
opt_level: OptimizationLevel::Default, opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(), target: CodeGenTargetMachineOptions::from_host_triple(),
}; };
let (registry, handles) = WorkerRegistry::create_workers( let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
threads,
top_level,
&llvm_options,
&f
);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
} }
#[test]
fn test_classes_list_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);
let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into());
assert!(ListType::is_type(llvm_list.as_base_type(), llvm_usize).is_ok());
}
#[test]
fn test_classes_range_type_new() {
let ctx = inkwell::context::Context::create();
let llvm_range = RangeType::new(&ctx);
assert!(RangeType::is_type(llvm_range.as_base_type()).is_ok());
}
#[test]
fn test_classes_ndarray_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into());
assert!(NDArrayType::is_type(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
}

View File

@ -1,7 +1,26 @@
#![warn(clippy::all)] #![deny(
#![allow(dead_code)] future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
dead_code,
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::enum_glob_use,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::similar_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
pub mod codegen; pub mod codegen;
pub mod symbol_resolver; pub mod symbol_resolver;
pub mod toplevel; pub mod toplevel;
pub mod typecheck; pub mod typecheck;
pub mod util;

View File

@ -1,22 +1,18 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use std::{collections::HashMap, collections::HashSet, fmt::Display}; use std::{collections::HashMap, collections::HashSet, fmt::Display};
use std::rc::Rc;
use crate::typecheck::typedef::TypeEnum;
use crate::{ use crate::{
codegen::CodeGenContext, codegen::{CodeGenContext, CodeGenerator},
toplevel::{DefinitionId, TopLevelDef, type_annotation::TypeAnnotation}, toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
};
use crate::{
codegen::CodeGenerator,
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{Type, Unifier}, typedef::{Type, TypeEnum, Unifier, VarMap},
}, },
}; };
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue}; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip}; use itertools::{chain, izip, Itertools};
use nac3parser::ast::{Constant, Expr, Location, StrRef}; use nac3parser::ast::{Constant, Expr, Location, StrRef};
use parking_lot::RwLock; use parking_lot::RwLock;
@ -43,7 +39,7 @@ impl SymbolValue {
constant: &Constant, constant: &Constant,
expected_ty: Type, expected_ty: Type,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
unifier: &mut Unifier unifier: &mut Unifier,
) -> Result<Self, String> { ) -> Result<Self, String> {
match constant { match constant {
Constant::None => { Constant::None => {
@ -66,24 +62,16 @@ impl SymbolValue {
} else { } else {
Err(format!("Expected {expected_ty:?}, but got str")) Err(format!("Expected {expected_ty:?}, but got str"))
} }
}, }
Constant::Int(i) => { Constant::Int(i) => {
if unifier.unioned(expected_ty, primitives.int32) { if unifier.unioned(expected_ty, primitives.int32) {
i32::try_from(*i) i32::try_from(*i).map(SymbolValue::I32).map_err(|e| e.to_string())
.map(SymbolValue::I32)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.int64) { } else if unifier.unioned(expected_ty, primitives.int64) {
i64::try_from(*i) i64::try_from(*i).map(SymbolValue::I64).map_err(|e| e.to_string())
.map(SymbolValue::I64)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint32) { } else if unifier.unioned(expected_ty, primitives.uint32) {
u32::try_from(*i) u32::try_from(*i).map(SymbolValue::U32).map_err(|e| e.to_string())
.map(SymbolValue::U32)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint64) { } else if unifier.unioned(expected_ty, primitives.uint64) {
u64::try_from(*i) u64::try_from(*i).map(SymbolValue::U64).map_err(|e| e.to_string())
.map(SymbolValue::U64)
.map_err(|e| e.to_string())
} else { } else {
Err(format!("Expected {}, but got int", unifier.stringify(expected_ty))) Err(format!("Expected {}, but got int", unifier.stringify(expected_ty)))
} }
@ -91,7 +79,10 @@ impl SymbolValue {
Constant::Tuple(t) => { Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty); let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty.as_ref() else { let TypeEnum::TTuple { ty } = expected_ty.as_ref() else {
return Err(format!("Expected {:?}, but got Tuple", expected_ty.get_type_name())) return Err(format!(
"Expected {:?}, but got Tuple",
expected_ty.get_type_name()
));
}; };
assert_eq!(ty.len(), t.len()); assert_eq!(ty.len(), t.len());
@ -109,7 +100,45 @@ impl SymbolValue {
} else { } else {
Err(format!("Expected {expected_ty:?}, but got float")) Err(format!("Expected {expected_ty:?}, but got float"))
} }
}, }
_ => Err(format!("Unsupported value type {constant:?}")),
}
}
/// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value.
///
/// * `constant` - The constant to create the value from.
pub fn from_constant_inferred(constant: &Constant) -> Result<Self, String> {
match constant {
Constant::None => Ok(SymbolValue::OptionNone),
Constant::Bool(b) => Ok(SymbolValue::Bool(*b)),
Constant::Str(s) => Ok(SymbolValue::Str(s.to_string())),
Constant::Int(i) => {
let i = *i;
if i >= 0 {
i32::try_from(i)
.map(SymbolValue::I32)
.or_else(|_| i64::try_from(i).map(SymbolValue::I64))
.map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
} else {
u32::try_from(i)
.map(SymbolValue::U32)
.or_else(|_| u64::try_from(i).map(SymbolValue::U64))
.map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
}
}
Constant::Tuple(t) => {
let elems = t
.iter()
.map(Self::from_constant_inferred)
.collect::<Result<Vec<SymbolValue>, _>>()?;
Ok(SymbolValue::Tuple(elems))
}
Constant::Float(f) => Ok(SymbolValue::Double(*f)),
_ => Err(format!("Unsupported value type {constant:?}")), _ => Err(format!("Unsupported value type {constant:?}")),
} }
} }
@ -125,28 +154,27 @@ impl SymbolValue {
SymbolValue::Double(_) => primitives.float, SymbolValue::Double(_) => primitives.float,
SymbolValue::Bool(_) => primitives.bool, SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => { SymbolValue::Tuple(vs) => {
let vs_tys = vs let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
.iter() unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
.map(|v| v.get_type(primitives, unifier))
.collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple {
ty: vs_tys,
})
} }
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option, SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
} }
} }
/// Returns the [`TypeAnnotation`] representing the data type of this value. /// Returns the [`TypeAnnotation`] representing the data type of this value.
pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation { pub fn get_type_annotation(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> TypeAnnotation {
match self { match self {
SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool), SymbolValue::Bool(..)
SymbolValue::Double(..) => TypeAnnotation::Primitive(primitives.float), | SymbolValue::Double(..)
SymbolValue::I32(..) => TypeAnnotation::Primitive(primitives.int32), | SymbolValue::I32(..)
SymbolValue::I64(..) => TypeAnnotation::Primitive(primitives.int64), | SymbolValue::I64(..)
SymbolValue::U32(..) => TypeAnnotation::Primitive(primitives.uint32), | SymbolValue::U32(..)
SymbolValue::U64(..) => TypeAnnotation::Primitive(primitives.uint64), | SymbolValue::U64(..)
SymbolValue::Str(..) => TypeAnnotation::Primitive(primitives.str), | SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)),
SymbolValue::Tuple(vs) => { SymbolValue::Tuple(vs) => {
let vs_tys = vs let vs_tys = vs
.iter() .iter()
@ -155,13 +183,13 @@ impl SymbolValue {
TypeAnnotation::Tuple(vs_tys) TypeAnnotation::Tuple(vs_tys)
} }
SymbolValue::OptionNone => TypeAnnotation::CustomClass { SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitives.option.get_obj_id(unifier), id: primitives.option.obj_id(unifier).unwrap(),
params: Vec::default(), params: Vec::default(),
}, },
SymbolValue::OptionSome(v) => { SymbolValue::OptionSome(v) => {
let ty = v.get_type_annotation(primitives, unifier); let ty = v.get_type_annotation(primitives, unifier);
TypeAnnotation::CustomClass { TypeAnnotation::CustomClass {
id: primitives.option.get_obj_id(unifier), id: primitives.option.obj_id(unifier).unwrap(),
params: vec![ty], params: vec![ty],
} }
} }
@ -169,7 +197,11 @@ impl SymbolValue {
} }
/// Returns the [`TypeEnum`] representing the data type of this value. /// Returns the [`TypeEnum`] representing the data type of this value.
pub fn get_type_enum(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Rc<TypeEnum> { pub fn get_type_enum(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> Rc<TypeEnum> {
let ty = self.get_type(primitives, unifier); let ty = self.get_type(primitives, unifier);
unifier.get_ty(ty) unifier.get_ty(ty)
} }
@ -200,6 +232,38 @@ impl Display for SymbolValue {
} }
} }
impl TryFrom<SymbolValue> for u64 {
type Error = ();
/// Tries to convert a [`SymbolValue`] into a [`u64`], returning [`Err`] if the value is not
/// numeric or if the value cannot be converted into a `u64` without overflow.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::U32(v) => Ok(u64::from(v)),
SymbolValue::U64(v) => Ok(v),
_ => Err(()),
}
}
}
impl TryFrom<SymbolValue> for i128 {
type Error = ();
/// Tries to convert a [`SymbolValue`] into a [`i128`], returning [`Err`] if the value is not
/// numeric.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => Ok(i128::from(v)),
SymbolValue::I64(v) => Ok(i128::from(v)),
SymbolValue::U32(v) => Ok(i128::from(v)),
SymbolValue::U64(v) => Ok(i128::from(v)),
_ => Err(()),
}
}
}
pub trait StaticValue { pub trait StaticValue {
/// Returns a unique identifier for this value. /// Returns a unique identifier for this value.
fn get_unique_identifier(&self) -> u64; fn get_unique_identifier(&self) -> u64;
@ -232,10 +296,10 @@ pub trait StaticValue {
#[derive(Clone)] #[derive(Clone)]
pub enum ValueEnum<'ctx> { pub enum ValueEnum<'ctx> {
/// [ValueEnum] representing a static value. /// [`ValueEnum`] representing a static value.
Static(Arc<dyn StaticValue + Send + Sync>), Static(Arc<dyn StaticValue + Send + Sync>),
/// [ValueEnum] representing a dynamic value. /// [`ValueEnum`] representing a dynamic value.
Dynamic(BasicValueEnum<'ctx>), Dynamic(BasicValueEnum<'ctx>),
} }
@ -270,7 +334,6 @@ impl<'ctx> From<StructValue<'ctx>> for ValueEnum<'ctx> {
} }
impl<'ctx> ValueEnum<'ctx> { impl<'ctx> ValueEnum<'ctx> {
/// Converts this [`ValueEnum`] to a [`BasicValueEnum`]. /// Converts this [`ValueEnum`] to a [`BasicValueEnum`].
pub fn to_basic_value_enum<'a>( pub fn to_basic_value_enum<'a>(
self, self,
@ -312,7 +375,7 @@ pub trait SymbolResolver {
&self, &self,
_unifier: &mut Unifier, _unifier: &mut Unifier,
_top_level_defs: &[Arc<RwLock<TopLevelDef>>], _top_level_defs: &[Arc<RwLock<TopLevelDef>>],
_primitives: &PrimitiveStore _primitives: &PrimitiveStore,
) -> Result<(), String> { ) -> Result<(), String> {
Ok(()) Ok(())
} }
@ -325,12 +388,12 @@ thread_local! {
"float".into(), "float".into(),
"bool".into(), "bool".into(),
"virtual".into(), "virtual".into(),
"list".into(),
"tuple".into(), "tuple".into(),
"str".into(), "str".into(),
"Exception".into(), "Exception".into(),
"uint32".into(), "uint32".into(),
"uint64".into(), "uint64".into(),
"Literal".into(),
]; ];
} }
@ -349,12 +412,12 @@ pub fn parse_type_annotation<T>(
let float_id = ids[2]; let float_id = ids[2];
let bool_id = ids[3]; let bool_id = ids[3];
let virtual_id = ids[4]; let virtual_id = ids[4];
let list_id = ids[5]; let tuple_id = ids[5];
let tuple_id = ids[6]; let str_id = ids[6];
let str_id = ids[7]; let exn_id = ids[7];
let exn_id = ids[8]; let uint32_id = ids[8];
let uint32_id = ids[9]; let uint64_id = ids[9];
let uint64_id = ids[10]; let literal_id = ids[10];
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| { let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
if *id == int32_id { if *id == int32_id {
@ -379,40 +442,29 @@ pub fn parse_type_annotation<T>(
let def = top_level_defs[obj_id.0].read(); let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if !type_vars.is_empty() { if !type_vars.is_empty() {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"Unexpected number of type parameters: expected {} but got 0", "Unexpected number of type parameters: expected {} but got 0",
type_vars.len() type_vars.len()
), )]));
]))
} }
let fields = chain( let fields = chain(
fields.iter().map(|(k, v, m)| (*k, (*v, *m))), fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))), methods.iter().map(|(k, v, _)| (*k, (*v, false))),
) )
.collect(); .collect();
Ok(unifier.add_ty(TypeEnum::TObj { Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: VarMap::default() }))
obj_id,
fields,
params: HashMap::default(),
}))
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!("Cannot use function name as type at {loc}")]))
format!("Cannot use function name as type at {loc}"),
]))
} }
} else { } else {
let ty = resolver let ty =
.get_symbol_type(unifier, top_level_defs, primitives, *id) resolver.get_symbol_type(unifier, top_level_defs, primitives, *id).map_err(
.map_err(|e| HashSet::from([ |e| HashSet::from([format!("Unknown type annotation at {loc}: {e}")]),
format!("Unknown type annotation at {loc}: {e}"), )?;
]))?;
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
Ok(ty) Ok(ty)
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!("Unknown type annotation {id} at {loc}")]))
format!("Unknown type annotation {id} at {loc}"),
]))
} }
} }
} }
@ -422,9 +474,6 @@ pub fn parse_type_annotation<T>(
if *id == virtual_id { if *id == virtual_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?; let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} else if *id == list_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
} else if *id == tuple_id { } else if *id == tuple_id {
if let Tuple { elts, .. } = &slice.node { if let Tuple { elts, .. } = &slice.node {
let ty = elts let ty = elts
@ -435,10 +484,31 @@ pub fn parse_type_annotation<T>(
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty })) Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
} else { } else {
Err(HashSet::from([ Err(HashSet::from(["Expected multiple elements for tuple".into()]))
"Expected multiple elements for tuple".into()
]))
} }
} else if *id == literal_id {
let mut parse_literal = |elt: &Expr<T>| {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)?;
let ty_enum = &*unifier.get_ty_immutable(ty);
match ty_enum {
TypeEnum::TLiteral { values, .. } => Ok(values.clone()),
_ => Err(HashSet::from([format!(
"Expected literal in type argument for Literal at {}",
elt.location
)])),
}
};
let values = if let Tuple { elts, .. } = &slice.node {
elts.iter().map(&mut parse_literal).collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_literal(slice)?]
}
.into_iter()
.flatten()
.collect_vec();
Ok(unifier.get_fresh_literal(values, Some(slice.location)))
} else { } else {
let types = if let Tuple { elts, .. } = &slice.node { let types = if let Tuple { elts, .. } = &slice.node {
elts.iter() elts.iter()
@ -454,15 +524,13 @@ pub fn parse_type_annotation<T>(
let def = top_level_defs[obj_id.0].read(); let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() { if types.len() != type_vars.len() {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"Unexpected number of type parameters: expected {} but got {}", "Unexpected number of type parameters: expected {} but got {}",
type_vars.len(), type_vars.len(),
types.len() types.len()
), )]));
]))
} }
let mut subst = HashMap::new(); let mut subst = VarMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) { for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id *id
@ -484,9 +552,7 @@ pub fn parse_type_annotation<T>(
})); }));
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst })) Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst }))
} else { } else {
Err(HashSet::from([ Err(HashSet::from(["Cannot use function name as type".into()]))
"Cannot use function name as type".into(),
]))
} }
} }
}; };
@ -497,14 +563,13 @@ pub fn parse_type_annotation<T>(
if let Name { id, .. } = &value.node { if let Name { id, .. } = &value.node {
subscript_name_handle(id, slice, unifier) subscript_name_handle(id, slice, unifier)
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!("unsupported type expression at {}", expr.location)]))
format!("unsupported type expression at {}", expr.location),
]))
} }
} }
_ => Err(HashSet::from([ Constant { value, .. } => SymbolValue::from_constant_inferred(value)
format!("unsupported type expression at {}", expr.location), .map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location)))
])), .map_err(|err| HashSet::from([err])),
_ => Err(HashSet::from([format!("unsupported type expression at {}", expr.location)])),
} }
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,298 @@
use std::convert::TryInto; use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{into_var_map, iter_type_vars, Mapping, TypeVarId, VarMap};
use nac3parser::ast::{Constant, Location}; use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
use super::*; use super::*;
/// All primitive types and functions in nac3core.
#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
pub enum PrimDef {
// Classes
Int32,
Int64,
Float,
Bool,
None,
Range,
Str,
Exception,
UInt32,
UInt64,
Option,
List,
NDArray,
// Member Functions
OptionIsSome,
OptionIsNone,
OptionUnwrap,
NDArrayCopy,
NDArrayFill,
FunInt32,
FunInt64,
FunUInt32,
FunUInt64,
FunFloat,
FunNpNDArray,
FunNpEmpty,
FunNpZeros,
FunNpOnes,
FunNpFull,
FunNpArray,
FunNpEye,
FunNpIdentity,
FunRound,
FunRound64,
FunNpRound,
FunRangeInit,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunNpFloor,
FunCeil,
FunCeil64,
FunNpCeil,
FunLen,
FunMin,
FunNpMin,
FunNpMinimum,
FunMax,
FunNpMax,
FunNpMaximum,
FunAbs,
FunNpIsNan,
FunNpIsInf,
FunNpSin,
FunNpCos,
FunNpExp,
FunNpExp2,
FunNpLog,
FunNpLog10,
FunNpLog2,
FunNpFabs,
FunNpSqrt,
FunNpRint,
FunNpTan,
FunNpArcsin,
FunNpArccos,
FunNpArctan,
FunNpSinh,
FunNpCosh,
FunNpTanh,
FunNpArcsinh,
FunNpArccosh,
FunNpArctanh,
FunNpExpm1,
FunNpCbrt,
FunSpSpecErf,
FunSpSpecErfc,
FunSpSpecGamma,
FunSpSpecGammaln,
FunSpSpecJ0,
FunSpSpecJ1,
FunNpArctan2,
FunNpCopysign,
FunNpFmax,
FunNpFmin,
FunNpLdExp,
FunNpHypot,
FunNpNextAfter,
// Top-Level Functions
FunSome,
}
/// Associated details of a [`PrimDef`]
pub enum PrimDefDetails {
PrimFunction { name: &'static str, simple_name: &'static str },
PrimClass { name: &'static str },
}
impl PrimDef {
/// Get the assigned [`DefinitionId`] of this [`PrimDef`].
///
/// The assigned definition ID is defined by the position this [`PrimDef`] enum unit variant is defined at,
/// with the first `PrimDef`'s definition id being `0`.
#[must_use]
pub fn id(&self) -> DefinitionId {
DefinitionId(*self as usize)
}
/// Check if a definition ID is that of a [`PrimDef`].
#[must_use]
pub fn contains_id(id: DefinitionId) -> bool {
Self::iter().any(|prim| prim.id() == id)
}
/// Get the definition "simple name" of this [`PrimDef`].
///
/// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::simple_name`].
///
/// If the [`PrimDef`] is a class, this returns [`None`].
#[must_use]
pub fn simple_name(&self) -> &'static str {
match self.details() {
PrimDefDetails::PrimFunction { simple_name, .. } => simple_name,
PrimDefDetails::PrimClass { .. } => {
panic!("PrimDef {self:?} has no simple_name as it is not a function.")
}
}
}
/// Get the definition "name" of this [`PrimDef`].
///
/// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::name`].
///
/// If the [`PrimDef`] is a class, this corresponds to [`TopLevelDef::Class::name`].
#[must_use]
pub fn name(&self) -> &'static str {
match self.details() {
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name,
}
}
/// Get the associated details of this [`PrimDef`]
#[must_use]
pub fn details(self) -> PrimDefDetails {
fn class(name: &'static str) -> PrimDefDetails {
PrimDefDetails::PrimClass { name }
}
fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails {
PrimDefDetails::PrimFunction { simple_name: simple_name.unwrap_or(name), name }
}
match self {
PrimDef::Int32 => class("int32"),
PrimDef::Int64 => class("int64"),
PrimDef::Float => class("float"),
PrimDef::Bool => class("bool"),
PrimDef::None => class("none"),
PrimDef::Range => class("range"),
PrimDef::Str => class("str"),
PrimDef::Exception => class("Exception"),
PrimDef::UInt32 => class("uint32"),
PrimDef::UInt64 => class("uint64"),
PrimDef::Option => class("Option"),
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")),
PrimDef::List => class("list"),
PrimDef::NDArray => class("ndarray"),
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")),
PrimDef::FunInt32 => fun("int32", None),
PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunUInt32 => fun("uint32", None),
PrimDef::FunUInt64 => fun("uint64", None),
PrimDef::FunFloat => fun("float", None),
PrimDef::FunNpNDArray => fun("np_ndarray", None),
PrimDef::FunNpEmpty => fun("np_empty", None),
PrimDef::FunNpZeros => fun("np_zeros", None),
PrimDef::FunNpOnes => fun("np_ones", None),
PrimDef::FunNpFull => fun("np_full", None),
PrimDef::FunNpArray => fun("np_array", None),
PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None),
PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunNpFloor => fun("np_floor", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunNpCeil => fun("np_ceil", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunNpMin => fun("np_min", None),
PrimDef::FunNpMinimum => fun("np_minimum", None),
PrimDef::FunMax => fun("max", None),
PrimDef::FunNpMax => fun("np_max", None),
PrimDef::FunNpMaximum => fun("np_maximum", None),
PrimDef::FunAbs => fun("abs", None),
PrimDef::FunNpIsNan => fun("np_isnan", None),
PrimDef::FunNpIsInf => fun("np_isinf", None),
PrimDef::FunNpSin => fun("np_sin", None),
PrimDef::FunNpCos => fun("np_cos", None),
PrimDef::FunNpExp => fun("np_exp", None),
PrimDef::FunNpExp2 => fun("np_exp2", None),
PrimDef::FunNpLog => fun("np_log", None),
PrimDef::FunNpLog10 => fun("np_log10", None),
PrimDef::FunNpLog2 => fun("np_log2", None),
PrimDef::FunNpFabs => fun("np_fabs", None),
PrimDef::FunNpSqrt => fun("np_sqrt", None),
PrimDef::FunNpRint => fun("np_rint", None),
PrimDef::FunNpTan => fun("np_tan", None),
PrimDef::FunNpArcsin => fun("np_arcsin", None),
PrimDef::FunNpArccos => fun("np_arccos", None),
PrimDef::FunNpArctan => fun("np_arctan", None),
PrimDef::FunNpSinh => fun("np_sinh", None),
PrimDef::FunNpCosh => fun("np_cosh", None),
PrimDef::FunNpTanh => fun("np_tanh", None),
PrimDef::FunNpArcsinh => fun("np_arcsinh", None),
PrimDef::FunNpArccosh => fun("np_arccosh", None),
PrimDef::FunNpArctanh => fun("np_arctanh", None),
PrimDef::FunNpExpm1 => fun("np_expm1", None),
PrimDef::FunNpCbrt => fun("np_cbrt", None),
PrimDef::FunSpSpecErf => fun("sp_spec_erf", None),
PrimDef::FunSpSpecErfc => fun("sp_spec_erfc", None),
PrimDef::FunSpSpecGamma => fun("sp_spec_gamma", None),
PrimDef::FunSpSpecGammaln => fun("sp_spec_gammaln", None),
PrimDef::FunSpSpecJ0 => fun("sp_spec_j0", None),
PrimDef::FunSpSpecJ1 => fun("sp_spec_j1", None),
PrimDef::FunNpArctan2 => fun("np_arctan2", None),
PrimDef::FunNpCopysign => fun("np_copysign", None),
PrimDef::FunNpFmax => fun("np_fmax", None),
PrimDef::FunNpFmin => fun("np_fmin", None),
PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunSome => fun("Some", None),
}
}
}
/// Asserts that a [`PrimDef`] is in an allowlist.
///
/// Like `debug_assert!`, this statements of this function are only
/// enabled if `cfg!(debug_assertions)` is true.
pub fn debug_assert_prim_is_allowed(prim: PrimDef, allowlist: &[PrimDef]) {
if cfg!(debug_assertions) {
let allowed = allowlist.iter().any(|p| *p == prim);
assert!(
allowed,
"Disallowed primitive definition. Got {prim:?}, but expects it to be in {allowlist:?}"
);
}
}
/// Construct the fields of class `Exception`
/// See [`TypeEnum::TObj::fields`] and [`TopLevelDef::Class::fields`]
#[must_use]
pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef, Type, bool)> {
vec![
("__name__".into(), int32, true),
("__file__".into(), str, true),
("__line__".into(), int32, true),
("__col__".into(), int32, true),
("__func__".into(), str, true),
("__message__".into(), str, true),
("__param0__".into(), int64, true),
("__param1__".into(), int64, true),
("__param2__".into(), int64, true),
]
}
impl TopLevelDef { impl TopLevelDef {
pub fn to_string(&self, unifier: &mut Unifier) -> String { pub fn to_string(&self, unifier: &mut Unifier) -> String {
match self { match self {
@ -44,94 +332,133 @@ impl TopLevelDef {
impl TopLevelComposer { impl TopLevelComposer {
#[must_use] #[must_use]
pub fn make_primitives() -> (PrimitiveStore, Unifier) { pub fn make_primitives(size_t: u32) -> (PrimitiveStore, Unifier) {
let mut unifier = Unifier::new(); let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0), obj_id: PrimDef::Int32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1), obj_id: PrimDef::Int64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2), obj_id: PrimDef::Float.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3), obj_id: PrimDef::Bool.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4), obj_id: PrimDef::None.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let range = unifier.add_ty(TypeEnum::TObj { let range = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: PrimDef::Range.id(),
fields: HashMap::new(), fields: [
params: HashMap::new(), ("start".into(), (int32, true)),
}); ("stop".into(), (int32, true)),
let str = unifier.add_ty(TypeEnum::TObj { ("step".into(), (int32, true)),
obj_id: DefinitionId(6),
fields: HashMap::new(),
params: HashMap::new(),
});
let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(7),
fields: vec![
("__name__".into(), (int32, true)),
("__file__".into(), (str, true)),
("__line__".into(), (int32, true)),
("__col__".into(), (int32, true)),
("__func__".into(), (str, true)),
("__message__".into(), (str, true)),
("__param0__".into(), (int64, true)),
("__param1__".into(), (int64, true)),
("__param2__".into(), (int64, true)),
] ]
.into_iter() .into_iter()
.collect::<HashMap<_, _>>(), .collect(),
params: HashMap::new(), params: VarMap::new(),
});
let str = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Str.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Exception.id(),
fields: make_exception_fields(int32, int64, str)
.into_iter()
.map(|(name, ty, mutable)| (name, (ty, mutable)))
.collect(),
params: VarMap::new(),
}); });
let uint32 = unifier.add_ty(TypeEnum::TObj { let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(8), obj_id: PrimDef::UInt32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let uint64 = unifier.add_ty(TypeEnum::TObj { let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(9), obj_id: PrimDef::UInt64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let option_type_var = unifier.get_fresh_var(Some("option_type_var".into()), None); let option_type_var = unifier.get_fresh_var(Some("option_type_var".into()), None);
let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: bool, ret: bool,
vars: HashMap::from([(option_type_var.1, option_type_var.0)]), vars: into_var_map([option_type_var]),
})); }));
let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: option_type_var.0, ret: option_type_var.ty,
vars: HashMap::from([(option_type_var.1, option_type_var.0)]), vars: into_var_map([option_type_var]),
})); }));
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(10), obj_id: PrimDef::Option.id(),
fields: vec![ fields: vec![
("is_some".into(), (is_some_type_fun_ty, true)), (PrimDef::OptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)),
("is_none".into(), (is_some_type_fun_ty, true)), (PrimDef::OptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)),
("unwrap".into(), (unwrap_fun_ty, true)), (PrimDef::OptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)),
] ]
.into_iter() .into_iter()
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
params: HashMap::from([(option_type_var.1, option_type_var.0)]), params: into_var_map([option_type_var]),
}); });
let size_t_ty = match size_t {
32 => uint32,
64 => uint64,
_ => unreachable!(),
};
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: Mapping::new(),
params: into_var_map([list_elem_tvar]),
});
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
ret: ndarray_copy_fun_ret_ty.ty,
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}));
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "value".into(),
ty: ndarray_dtype_tvar.ty,
default_value: None,
}],
ret: none,
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}));
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(),
fields: Mapping::from([
(PrimDef::NDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)),
(PrimDef::NDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)),
]),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
});
unifier.unify(ndarray_copy_fun_ret_ty.ty, ndarray).unwrap();
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
int64, int64,
@ -144,7 +471,11 @@ impl TopLevelComposer {
str, str,
exception, exception,
option, option,
list,
ndarray,
size_t,
}; };
unifier.put_primitive_store(&primitives);
crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier);
(primitives, unifier) (primitives, unifier)
} }
@ -153,7 +484,7 @@ impl TopLevelComposer {
/// when first registering, the `type_vars`, fields, methods, ancestors are invalid /// when first registering, the `type_vars`, fields, methods, ancestors are invalid
#[must_use] #[must_use]
pub fn make_top_level_class_def( pub fn make_top_level_class_def(
index: usize, obj_id: DefinitionId,
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>, resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
name: StrRef, name: StrRef,
constructor: Option<Type>, constructor: Option<Type>,
@ -161,9 +492,10 @@ impl TopLevelComposer {
) -> TopLevelDef { ) -> TopLevelDef {
TopLevelDef::Class { TopLevelDef::Class {
name, name,
object_id: DefinitionId(index), object_id: obj_id,
type_vars: Vec::default(), type_vars: Vec::default(),
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(),
methods: Vec::default(), methods: Vec::default(),
ancestors: Vec::default(), ancestors: Vec::default(),
constructor, constructor,
@ -272,13 +604,11 @@ impl TopLevelComposer {
} }
/// get the `var_id` of a given `TVar` type /// get the `var_id` of a given `TVar` type
pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<u32, HashSet<String>> { pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<TypeVarId, HashSet<String>> {
if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() {
Ok(*id) Ok(*id)
} else { } else {
Err(HashSet::from([ Err(HashSet::from(["not type var".to_string()]))
"not type var".to_string(),
]))
} }
} }
@ -295,12 +625,14 @@ impl TopLevelComposer {
let ( let (
TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }), TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }),
TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }), TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }),
) = (this, other) else { ) = (this, other)
else {
unreachable!("this function must be called with function type") unreachable!("this function must be called with function type")
}; };
// check args // check args
let args_ok = this_args let args_ok =
this_args
.iter() .iter()
.map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap())) .map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap()))
.zip(other_args.iter().map(|FuncArg { name, ty, .. }| { .zip(other_args.iter().map(|FuncArg { name, ty, .. }| {
@ -356,12 +688,10 @@ impl TopLevelComposer {
} }
} => } =>
{ {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"redundant type annotation for class fields at {}", "redundant type annotation for class fields at {}",
s.location s.location
), )]))
]))
} }
ast::StmtKind::Assign { targets, .. } => { ast::StmtKind::Assign { targets, .. } => {
for t in targets { for t in targets {
@ -438,7 +768,7 @@ impl TopLevelComposer {
TypeAnnotation::CustomClass { id: e_id, params: e_param }, TypeAnnotation::CustomClass { id: e_id, params: e_param },
) => { ) => {
*f_id == *e_id *f_id == *e_id
&& *f_id == primitive.option.get_obj_id(unifier) && *f_id == primitive.option.obj_id(unifier).unwrap()
&& (f_param.is_empty() && (f_param.is_empty()
|| (f_param.len() == 1 || (f_param.len() == 1
&& e_param.len() == 1 && e_param.len() == 1
@ -485,93 +815,126 @@ pub fn parse_parameter_default_value(
Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( Constant::Tuple(tuple) => Ok(SymbolValue::Tuple(
tuple.iter().map(|x| handle_constant(x, loc)).collect::<Result<Vec<_>, _>>()?, tuple.iter().map(|x| handle_constant(x, loc)).collect::<Result<Vec<_>, _>>()?,
)), )),
Constant::None => Err(HashSet::from([ Constant::None => Err(HashSet::from([format!(
format!(
"`None` is not supported, use `none` for option type instead ({loc})" "`None` is not supported, use `none` for option type instead ({loc})"
), )])),
])),
_ => unimplemented!("this constant is not supported at {}", loc), _ => unimplemented!("this constant is not supported at {}", loc),
} }
} }
match &default.node { match &default.node {
ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location), ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location),
ast::ExprKind::Call { func, args, .. } if args.len() == 1 => { ast::ExprKind::Call { func, args, .. } if args.len() == 1 => match &func.node {
match &func.node {
ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node { ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => { ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
let v: Result<i64, _> = (*v).try_into(); let v: Result<i64, _> = (*v).try_into();
match v { match v {
Ok(v) => Ok(SymbolValue::I64(v)), Ok(v) => Ok(SymbolValue::I64(v)),
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!(
format!("default param value out of range at {}", default.location) "default param value out of range at {}",
])), default.location
)])),
} }
} }
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!(
format!("only allow constant integer here at {}", default.location), "only allow constant integer here at {}",
])) default.location
} )])),
},
ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node { ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => { ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
let v: Result<u32, _> = (*v).try_into(); let v: Result<u32, _> = (*v).try_into();
match v { match v {
Ok(v) => Ok(SymbolValue::U32(v)), Ok(v) => Ok(SymbolValue::U32(v)),
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!(
format!("default param value out of range at {}", default.location), "default param value out of range at {}",
])), default.location
)])),
} }
} }
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!(
format!("only allow constant integer here at {}", default.location), "only allow constant integer here at {}",
])) default.location
} )])),
},
ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node { ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => { ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
let v: Result<u64, _> = (*v).try_into(); let v: Result<u64, _> = (*v).try_into();
match v { match v {
Ok(v) => Ok(SymbolValue::U64(v)), Ok(v) => Ok(SymbolValue::U64(v)),
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!(
format!("default param value out of range at {}", default.location), "default param value out of range at {}",
])), default.location
)])),
} }
} }
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!(
format!("only allow constant integer here at {}", default.location), "only allow constant integer here at {}",
])) default.location
} )])),
ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok( },
SymbolValue::OptionSome( ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok(SymbolValue::OptionSome(
Box::new(parse_parameter_default_value(&args[0], resolver)?) Box::new(parse_parameter_default_value(&args[0], resolver)?),
) )),
), _ => Err(HashSet::from([format!(
_ => Err(HashSet::from([ "unsupported default parameter at {}",
format!("unsupported default parameter at {}", default.location), default.location
])), )])),
} },
} ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(
ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts elts.iter()
.iter()
.map(|x| parse_parameter_default_value(x, resolver)) .map(|x| parse_parameter_default_value(x, resolver))
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?,
)), )),
ast::ExprKind::Name { id, .. } if id == &"none".into() => Ok(SymbolValue::OptionNone), ast::ExprKind::Name { id, .. } if id == &"none".into() => Ok(SymbolValue::OptionNone),
ast::ExprKind::Name { id, .. } => { ast::ExprKind::Name { id, .. } => {
resolver.get_default_param_value(default).ok_or_else( resolver.get_default_param_value(default).ok_or_else(|| {
|| HashSet::from([ HashSet::from([format!(
format!(
"`{}` cannot be used as a default parameter at {} \ "`{}` cannot be used as a default parameter at {} \
(not primitive type, option or tuple / not defined?)", (not primitive type, option or tuple / not defined?)",
id, id, default.location
default.location )])
), })
])
)
} }
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!(
format!(
"unsupported default parameter (not primitive type, option or tuple) at {}", "unsupported default parameter (not primitive type, option or tuple) at {}",
default.location default.location
), )])),
])) }
}
/// Obtains the element type of an array-like type.
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(unifier, ty).0
}
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
arraylike_flatten_element_type(unifier, iter_type_vars(params).next().unwrap().ty)
}
_ => ty,
}
}
/// Obtains the number of dimensions of an array-like type.
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let ndims = unpack_ndarray_var_tys(unifier, ty).1;
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
};
if values.len() > 1 {
todo!("Getting num of dimensions for ndarray with more than one ndim bound is unimplemented")
}
u64::try_from(values[0].clone()).unwrap()
}
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
arraylike_get_ndims(unifier, iter_type_vars(params).next().unwrap().ty) + 1
}
_ => 0,
} }
} }

View File

@ -8,14 +8,19 @@ use std::{
use super::codegen::CodeGenContext; use super::codegen::CodeGenContext;
use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier}; use super::typecheck::typedef::{
FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap,
};
use crate::{ use crate::{
codegen::CodeGenerator, codegen::CodeGenerator,
symbol_resolver::{SymbolResolver, ValueEnum}, symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{type_inferencer::CodeLocation, typedef::CallId}, typecheck::{
type_inferencer::CodeLocation,
typedef::{CallId, TypeVarId},
},
}; };
use inkwell::values::BasicValueEnum; use inkwell::values::BasicValueEnum;
use itertools::{izip, Itertools}; use itertools::Itertools;
use nac3parser::ast::{self, Location, Stmt, StrRef}; use nac3parser::ast::{self, Location, Stmt, StrRef};
use parking_lot::RwLock; use parking_lot::RwLock;
@ -25,14 +30,14 @@ pub struct DefinitionId(pub usize);
pub mod builtins; pub mod builtins;
pub mod composer; pub mod composer;
pub mod helper; pub mod helper;
pub mod numpy;
pub mod type_annotation; pub mod type_annotation;
use composer::*; use composer::*;
use type_annotation::*; use type_annotation::*;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
type GenCallCallback = Box< type GenCallCallback = dyn for<'ctx, 'a> Fn(
dyn for<'ctx, 'a> Fn(
&mut CodeGenContext<'ctx, 'a>, &mut CodeGenContext<'ctx, 'a>,
Option<(Type, ValueEnum<'ctx>)>, Option<(Type, ValueEnum<'ctx>)>,
(&FunSignature, DefinitionId), (&FunSignature, DefinitionId),
@ -40,19 +45,25 @@ type GenCallCallback = Box<
&mut dyn CodeGenerator, &mut dyn CodeGenerator,
) -> Result<Option<BasicValueEnum<'ctx>>, String> ) -> Result<Option<BasicValueEnum<'ctx>>, String>
+ Send + Send
+ Sync, + Sync;
>;
pub struct GenCall { pub struct GenCall {
fp: GenCallCallback, fp: Box<GenCallCallback>,
} }
impl GenCall { impl GenCall {
#[must_use] #[must_use]
pub fn new(fp: GenCallCallback) -> GenCall { pub fn new(fp: Box<GenCallCallback>) -> GenCall {
GenCall { fp } GenCall { fp }
} }
/// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given
/// `reason`.
#[must_use]
pub fn create_dummy(reason: String) -> GenCall {
Self::new(Box::new(move |_, _, _, _, _| unreachable!("{reason}")))
}
pub fn run<'ctx>( pub fn run<'ctx>(
&self, &self,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -75,7 +86,7 @@ impl Debug for GenCall {
pub struct FunInstance { pub struct FunInstance {
pub body: Arc<Vec<Stmt<Option<Type>>>>, pub body: Arc<Vec<Stmt<Option<Type>>>>,
pub calls: Arc<HashMap<CodeLocation, CallId>>, pub calls: Arc<HashMap<CodeLocation, CallId>>,
pub subst: HashMap<u32, Type>, pub subst: VarMap,
pub unifier_id: usize, pub unifier_id: usize,
} }
@ -84,7 +95,7 @@ pub enum TopLevelDef {
Class { Class {
/// Name for error messages and symbols. /// Name for error messages and symbols.
name: StrRef, name: StrRef,
/// Object ID used for [TypeEnum]. /// Object ID used for [`TypeEnum`].
object_id: DefinitionId, object_id: DefinitionId,
/// type variables bounded to the class. /// type variables bounded to the class.
type_vars: Vec<Type>, type_vars: Vec<Type>,
@ -92,6 +103,10 @@ pub enum TopLevelDef {
/// ///
/// Name and type is mutable. /// Name and type is mutable.
fields: Vec<(StrRef, Type, bool)>, fields: Vec<(StrRef, Type, bool)>,
/// Class Attributes.
///
/// Name, type, value.
attributes: Vec<(StrRef, Type, ast::Constant)>,
/// Class methods, pointing to the corresponding function definition. /// Class methods, pointing to the corresponding function definition.
methods: Vec<(StrRef, Type, DefinitionId)>, methods: Vec<(StrRef, Type, DefinitionId)>,
/// Ancestor classes, including itself. /// Ancestor classes, including itself.
@ -111,7 +126,7 @@ pub enum TopLevelDef {
/// Function signature. /// Function signature.
signature: Type, signature: Type,
/// Instantiated type variable IDs. /// Instantiated type variable IDs.
var_id: Vec<u32>, var_id: Vec<TypeVarId>,
/// Function instance to symbol mapping /// Function instance to symbol mapping
/// ///
/// * Key: String representation of type variable values, sorted by variable ID in ascending /// * Key: String representation of type variable values, sorted by variable ID in ascending

View File

@ -0,0 +1,85 @@
use crate::{
toplevel::helper::PrimDef,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
},
};
use itertools::Itertools;
/// Creates a `ndarray` [`Type`] with the given type arguments.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn make_ndarray_ty(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
}
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
if dtype.is_none() && ndims.is_none() {
return ndarray;
}
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
debug_assert_eq!(tvar_ids.len(), 2);
let mut tvar_subst = VarMap::new();
if let Some(dtype) = dtype {
tvar_subst.insert(tvar_ids[0], dtype);
}
if let Some(ndims) = ndims {
tvar_subst.insert(tvar_ids[1], ndims);
}
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
debug_assert_eq!(params.len(), 2);
params
.iter()
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
}

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [22]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar11]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar11\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,9 +5,9 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [24]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [29]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
] ]

View File

@ -3,11 +3,11 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar10, typevar11]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar10\", \"typevar11\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n",
] ]

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [30]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [38]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
] ]

View File

@ -1,3 +1,6 @@
use super::*;
use crate::toplevel::helper::PrimDef;
use crate::typecheck::typedef::into_var_map;
use crate::{ use crate::{
codegen::CodeGenContext, codegen::CodeGenContext,
symbol_resolver::{SymbolResolver, ValueEnum}, symbol_resolver::{SymbolResolver, ValueEnum},
@ -8,13 +11,12 @@ use crate::{
}, },
}; };
use indoc::indoc; use indoc::indoc;
use nac3parser::ast::FileName;
use nac3parser::{ast::fold::Fold, parser::parse_program}; use nac3parser::{ast::fold::Fold, parser::parse_program};
use parking_lot::Mutex; use parking_lot::Mutex;
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use test_case::test_case; use test_case::test_case;
use super::*;
struct ResolverInternal { struct ResolverInternal {
id_to_type: Mutex<HashMap<StrRef, Type>>, id_to_type: Mutex<HashMap<StrRef, Type>>,
id_to_def: Mutex<HashMap<StrRef, DefinitionId>>, id_to_def: Mutex<HashMap<StrRef, DefinitionId>>,
@ -52,20 +54,24 @@ impl SymbolResolver for Resolver {
.id_to_type .id_to_type
.lock() .lock()
.get(&str) .get(&str)
.cloned() .copied()
.ok_or_else(|| format!("cannot find symbol `{}`", str)) .ok_or_else(|| format!("cannot find symbol `{str}`"))
} }
fn get_symbol_value<'ctx, 'a>( fn get_symbol_value<'ctx>(
&self, &self,
_: StrRef, _: StrRef,
_: &mut CodeGenContext<'ctx, 'a>, _: &mut CodeGenContext<'ctx, '_>,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
unimplemented!() unimplemented!()
} }
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> { fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.0.id_to_def.lock().get(&id).cloned() self.0
.id_to_def
.lock()
.get(&id)
.copied()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()])) .ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
} }
@ -111,13 +117,13 @@ impl SymbolResolver for Resolver {
"register" "register"
)] )]
fn test_simple_register(source: Vec<&str>) { fn test_simple_register(source: Vec<&str>) {
let mut composer: TopLevelComposer = Default::default(); let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
for s in source { for s in source {
let ast = parse_program(s, Default::default()).unwrap(); let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
composer.register_top_level(ast, None, "".into(), false).unwrap(); composer.register_top_level(ast, None, "", false).unwrap();
} }
} }
@ -131,14 +137,14 @@ fn test_simple_register(source: Vec<&str>) {
"register" "register"
)] )]
fn test_simple_register_without_constructor(source: &str) { fn test_simple_register_without_constructor(source: &str) {
let mut composer: TopLevelComposer = Default::default(); let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
let ast = parse_program(source, Default::default()).unwrap(); let ast = parse_program(source, FileName::default()).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
composer.register_top_level(ast, None, "".into(), true).unwrap(); composer.register_top_level(ast, None, "", true).unwrap();
} }
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
def fun(a: int32) -> int32: def fun(a: int32) -> int32:
return a return a
@ -152,35 +158,35 @@ fn test_simple_register_without_constructor(source: &str) {
return 3 return 3
"}, "},
], ],
vec![ &[
"fn[[a:0], 0]", "fn[[a:0], 0]",
"fn[[a:2], 4]", "fn[[a:2], 4]",
"fn[[b:1], 0]", "fn[[b:1], 0]",
], ],
vec![ &[
"fun", "fun",
"foo", "foo",
"f" "f"
]; ];
"function compose" "function compose"
)] )]
fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&str>) { fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
let mut composer: TopLevelComposer = Default::default(); let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = Arc::new(ResolverInternal { let internal_resolver = Arc::new(ResolverInternal {
id_to_def: Default::default(), id_to_def: Mutex::default(),
id_to_type: Default::default(), id_to_type: Mutex::default(),
class_names: Default::default(), class_names: Mutex::default(),
}); });
let resolver = let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>; Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source { for s in source {
let ast = parse_program(s, Default::default()).unwrap(); let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
let (id, def_id, ty) = let (id, def_id, ty) =
composer.register_top_level(ast, Some(resolver.clone()), "".into(), false).unwrap(); composer.register_top_level(ast, Some(resolver.clone()), "", false).unwrap();
internal_resolver.add_id_def(id, def_id); internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty { if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty); internal_resolver.add_id_type(id, ty);
@ -206,7 +212,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
} }
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(): class A():
a: int32 a: int32
@ -239,11 +245,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec![]; &[];
"simple class compose" "simple class compose"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class Generic_A(Generic[V], B): class Generic_A(Generic[V], B):
a: int64 a: int64
@ -261,11 +267,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec![]; &[];
"generic class" "generic class"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
def foo(a: list[int32], b: tuple[T, float]) -> A[B, bool]: def foo(a: list[int32], b: tuple[T, float]) -> A[B, bool]:
pass pass
@ -290,11 +296,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec![]; &[];
"list tuple generic" "list tuple generic"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(Generic[T, V]): class A(Generic[T, V]):
a: A[float, bool] a: A[float, bool]
@ -315,11 +321,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec![]; &[];
"self1" "self1"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(Generic[T]): class A(Generic[T]):
a: int32 a: int32
@ -349,11 +355,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec![]; &[];
"inheritance_override" "inheritance_override"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(Generic[T]): class A(Generic[T]):
def __init__(self): def __init__(self):
@ -362,11 +368,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec!["application of type vars to generic class is not currently supported (at unknown:4:24)"]; &["application of type vars to generic class is not currently supported (at unknown:4:24)"];
"err no type var in generic app" "err no type var in generic app"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(B): class A(B):
def __init__(self): def __init__(self):
@ -378,11 +384,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec!["cyclic inheritance detected"]; &["cyclic inheritance detected"];
"cyclic1" "cyclic1"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(B[bool, int64]): class A(B[bool, int64]):
def __init__(self): def __init__(self):
@ -399,30 +405,30 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"}, "},
], ],
vec!["cyclic inheritance detected"]; &["cyclic inheritance detected"];
"cyclic2" "cyclic2"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A: class A:
pass pass
"} "}
], ],
vec!["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"]; &["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"];
"simple pass in class" "simple pass in class"
)] )]
#[test_case( #[test_case(
vec![indoc! {" &[indoc! {"
class A: class A:
def __init__(): def __init__():
pass pass
"}], "}],
vec!["__init__ method must have a `self` parameter (at unknown:2:5)"]; &["__init__ method must have a `self` parameter (at unknown:2:5)"];
"err no self_1" "err no self_1"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(B, Generic[T], C): class A(B, Generic[T], C):
def __init__(self): def __init__(self):
@ -440,11 +446,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
"} "}
], ],
vec!["a class definition can only have at most one base class declaration and one generic declaration (at unknown:1:24)"]; &["a class definition can only have at most one base class declaration and one generic declaration (at unknown:1:24)"];
"err multiple inheritance" "err multiple inheritance"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(Generic[T]): class A(Generic[T]):
a: int32 a: int32
@ -465,11 +471,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec!["method fun has same name as ancestors' method, but incompatible type"]; &["method fun has same name as ancestors' method, but incompatible type"];
"err_incompatible_inheritance_method" "err_incompatible_inheritance_method"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(Generic[T]): class A(Generic[T]):
a: int32 a: int32
@ -491,11 +497,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec!["field `a` has already declared in the ancestor classes"]; &["field `a` has already declared in the ancestor classes"];
"err_incompatible_inheritance_field" "err_incompatible_inheritance_field"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A: class A:
def __init__(self): def __init__(self):
@ -508,12 +514,12 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec!["duplicate definition of class `A` (at unknown:1:1)"]; &["duplicate definition of class `A` (at unknown:1:1)"];
"class same name" "class same name"
)] )]
fn test_analyze(source: Vec<&str>, res: Vec<&str>) { fn test_analyze(source: &[&str], res: &[&str]) {
let print = false; let print = false;
let mut composer: TopLevelComposer = Default::default(); let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar( let internal_resolver = make_internal_resolver_with_tvar(
vec![ vec![
@ -528,15 +534,15 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>; Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source { for s in source {
let ast = parse_program(s, Default::default()).unwrap(); let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
let (id, def_id, ty) = { let (id, def_id, ty) = {
match composer.register_top_level(ast, Some(resolver.clone()), "".into(), false) { match composer.register_top_level(ast, Some(resolver.clone()), "", false) {
Ok(x) => x, Ok(x) => x,
Err(msg) => { Err(msg) => {
if print { if print {
println!("{}", msg); println!("{msg}");
} else { } else {
assert_eq!(res[0], msg); assert_eq!(res[0], msg);
} }
@ -582,7 +588,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
return fib(n - 1) return fib(n - 1)
"} "}
], ],
vec![]; &[];
"simple function" "simple function"
)] )]
#[test_case( #[test_case(
@ -615,7 +621,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
return a.fun() + 2 return a.fun() + 2
"} "}
], ],
vec![]; &[];
"simple class body" "simple class body"
)] )]
#[test_case( #[test_case(
@ -640,7 +646,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
return [a, b] return [a, b]
"} "}
], ],
vec![]; &[];
"type var fun" "type var fun"
)] )]
#[test_case( #[test_case(
@ -661,7 +667,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
return ret if self.b else self.fun(self.a) return ret if self.b else self.fun(self.a)
"} "}
], ],
vec![]; &[];
"type var class" "type var class"
)] )]
#[test_case( #[test_case(
@ -685,12 +691,12 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
self.b = True self.b = True
"} "}
], ],
vec![]; &[];
"no_init_inst_check" "no_init_inst_check"
)] )]
fn test_inference(source: Vec<&str>, res: Vec<&str>) { fn test_inference(source: Vec<&str>, res: &[&str]) {
let print = true; let print = true;
let mut composer: TopLevelComposer = Default::default(); let mut composer = TopLevelComposer::new(Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar( let internal_resolver = make_internal_resolver_with_tvar(
vec![ vec![
@ -712,15 +718,15 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>; Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source { for s in source {
let ast = parse_program(s, Default::default()).unwrap(); let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
let (id, def_id, ty) = { let (id, def_id, ty) = {
match composer.register_top_level(ast, Some(resolver.clone()), "".into(), false) { match composer.register_top_level(ast, Some(resolver.clone()), "", false) {
Ok(x) => x, Ok(x) => x,
Err(msg) => { Err(msg) => {
if print { if print {
println!("{}", msg); println!("{msg}");
} else { } else {
assert_eq!(res[0], msg); assert_eq!(res[0], msg);
} }
@ -743,9 +749,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
} else { } else {
// skip 5 to skip primitives // skip 5 to skip primitives
let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier }; let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier };
for (_i, (def, _)) in for (def, _) in composer.definition_ast_list.iter().skip(composer.builtin_num) {
composer.definition_ast_list.iter().skip(composer.builtin_num).enumerate()
{
let def = &*def.read(); let def = &*def.read();
if let TopLevelDef::Function { instance_to_stmt, name, .. } = def { if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {
@ -754,7 +758,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
name, name,
instance_to_stmt.len() instance_to_stmt.len()
); );
for inst in instance_to_stmt.iter() { for inst in instance_to_stmt {
let ast = &inst.1.body; let ast = &inst.1.body;
for b in ast.iter() { for b in ast.iter() {
println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap()); println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap());
@ -772,22 +776,29 @@ fn make_internal_resolver_with_tvar(
unifier: &mut Unifier, unifier: &mut Unifier,
print: bool, print: bool,
) -> Arc<ResolverInternal> { ) -> Arc<ResolverInternal> {
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let res: Arc<ResolverInternal> = ResolverInternal { let res: Arc<ResolverInternal> = ResolverInternal {
id_to_def: Default::default(), id_to_def: Mutex::new(HashMap::from([("list".into(), PrimDef::List.id())])),
id_to_type: tvars id_to_type: tvars
.into_iter() .into_iter()
.map(|(name, range)| { .map(|(name, range)| {
(name, { (name, {
let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice(), None, None); let tvar = unifier.get_fresh_var_with_range(range.as_slice(), None, None);
if print { if print {
println!("{}: {:?}, typevar{}", name, ty, id); println!("{}: {:?}, typevar{}", name, tvar.ty, tvar.id);
} }
ty tvar.ty
}) })
}) })
.collect::<HashMap<_, _>>() .collect::<HashMap<_, _>>()
.into(), .into(),
class_names: Default::default(), class_names: Mutex::new(HashMap::from([("list".into(), list)])),
} }
.into(); .into();
if print { if print {
@ -807,8 +818,8 @@ impl<'a> Fold<Option<Type>> for TypeToStringFolder<'a> {
Ok(if let Some(ty) = user { Ok(if let Some(ty) = user {
self.unifier.internal_stringify( self.unifier.internal_stringify(
ty, ty,
&mut |id| format!("class{}", id.to_string()), &mut |id| format!("class{id}"),
&mut |id| format!("typevar{}", id.to_string()), &mut |id| format!("typevar{id}"),
&mut None, &mut None,
) )
} else { } else {

View File

@ -1,5 +1,8 @@
use crate::symbol_resolver::SymbolValue;
use super::*; use super::*;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::typecheck::typedef::VarMap;
use nac3parser::ast::Constant;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum TypeAnnotation { pub enum TypeAnnotation {
@ -13,17 +16,8 @@ pub enum TypeAnnotation {
// can only be CustomClassKind // can only be CustomClassKind
Virtual(Box<TypeAnnotation>), Virtual(Box<TypeAnnotation>),
TypeVar(Type), TypeVar(Type),
/// A constant used in the context of a const-generic variable. /// A `Literal` allowing a subset of literals.
Constant { Literal(Vec<Constant>),
/// The non-type variable associated with this constant.
///
/// Invoking [Unifier::get_ty] on this type will return a [TypeEnum::TVar] representing the
/// const generic variable of which this constant is associated with.
ty: Type,
/// The constant value of this constant.
value: SymbolValue
},
List(Box<TypeAnnotation>),
Tuple(Vec<TypeAnnotation>), Tuple(Vec<TypeAnnotation>),
} }
@ -34,9 +28,7 @@ impl TypeAnnotation {
Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty), Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty),
CustomClass { id, params } => { CustomClass { id, params } => {
let class_name = if let Some(ref top) = unifier.top_level { let class_name = if let Some(ref top) = unifier.top_level {
if let TopLevelDef::Class { name, .. } = if let TopLevelDef::Class { name, .. } = &*top.definitions.read()[id.0].read() {
&*top.definitions.read()[id.0].read()
{
(*name).into() (*name).into()
} else { } else {
unreachable!() unreachable!()
@ -44,24 +36,25 @@ impl TypeAnnotation {
} else { } else {
format!("class_def_{}", id.0) format!("class_def_{}", id.0)
}; };
format!( format!("{}{}", class_name, {
"{}{}", let param_list =
class_name, params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
{
let param_list = params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
if param_list.is_empty() { if param_list.is_empty() {
String::new() String::new()
} else { } else {
format!("[{param_list}]") format!("[{param_list}]")
} }
})
} }
) Literal(values) => {
format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", "))
} }
Constant { value, .. } => format!("Const({value})"),
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)), Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
List(ty) => format!("list[{}]", ty.stringify(unifier)),
Tuple(types) => { Tuple(types) => {
format!("tuple[{}]", types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ")) format!(
"tuple[{}]",
types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ")
)
} }
} }
} }
@ -73,19 +66,18 @@ impl TypeAnnotation {
/// generic variables associated with the definition. /// generic variables associated with the definition.
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass /// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
/// [`None`] when this function is invoked externally. /// [`None`] when this function is invoked externally.
pub fn parse_ast_to_type_annotation_kinds<T>( pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
resolver: &(dyn SymbolResolver + Send + Sync), resolver: &(dyn SymbolResolver + Send + Sync),
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
expr: &ast::Expr<T>, expr: &ast::Expr<T>,
// the key stores the type_var of this topleveldef::class, we only need this field here // the key stores the type_var of this topleveldef::class, we only need this field here
locked: HashMap<DefinitionId, Vec<Type>>, locked: HashMap<DefinitionId, Vec<Type>, S>,
type_var: Option<Type>,
) -> Result<TypeAnnotation, HashSet<String>> { ) -> Result<TypeAnnotation, HashSet<String>> {
let name_handle = |id: &StrRef, let name_handle = |id: &StrRef,
unifier: &mut Unifier, unifier: &mut Unifier,
locked: HashMap<DefinitionId, Vec<Type>>| { locked: HashMap<DefinitionId, Vec<Type>, S>| {
if id == &"int32".into() { if id == &"int32".into() {
Ok(TypeAnnotation::Primitive(primitives.int32)) Ok(TypeAnnotation::Primitive(primitives.int32))
} else if id == &"int64".into() { } else if id == &"int64".into() {
@ -101,7 +93,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} else if id == &"str".into() { } else if id == &"str".into() {
Ok(TypeAnnotation::Primitive(primitives.str)) Ok(TypeAnnotation::Primitive(primitives.str))
} else if id == &"Exception".into() { } else if id == &"Exception".into() {
Ok(TypeAnnotation::CustomClass { id: DefinitionId(7), params: Vec::default() }) Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) { } else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
let type_vars = { let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read(); let def_read = top_level_defs[obj_id.0].try_read();
@ -109,12 +101,10 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
if let TopLevelDef::Class { type_vars, .. } = &*def_read { if let TopLevelDef::Class { type_vars, .. } = &*def_read {
type_vars.clone() type_vars.clone()
} else { } else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"function cannot be used as a type (at {})", "function cannot be used as a type (at {})",
expr.location expr.location
), )]));
]))
} }
} else { } else {
locked.get(&obj_id).unwrap().clone() locked.get(&obj_id).unwrap().clone()
@ -122,29 +112,29 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
}; };
// check param number here // check param number here
if !type_vars.is_empty() { if !type_vars.is_empty() {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"expect {} type variable parameter but got 0 (at {})", "expect {} type variable parameter but got 0 (at {})",
type_vars.len(), type_vars.len(),
expr.location, expr.location,
), )]));
]))
} }
Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] }) Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] })
} else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { } else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) {
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).0; let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).ty;
unifier.unify(var, ty).unwrap(); unifier.unify(var, ty).unwrap();
Ok(TypeAnnotation::TypeVar(ty)) Ok(TypeAnnotation::TypeVar(ty))
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!(
format!("`{}` is not a valid type annotation (at {})", id, expr.location), "`{}` is not a valid type annotation (at {})",
])) id, expr.location
)]))
} }
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!(
format!("`{}` is not a valid type annotation (at {})", id, expr.location), "`{}` is not a valid type annotation (at {})",
])) id, expr.location
)]))
} }
}; };
@ -152,12 +142,12 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|id: &StrRef, |id: &StrRef,
slice: &ast::Expr<T>, slice: &ast::Expr<T>,
unifier: &mut Unifier, unifier: &mut Unifier,
mut locked: HashMap<DefinitionId, Vec<Type>>| { mut locked: HashMap<DefinitionId, Vec<Type>, S>| {
if ["virtual".into(), "Generic".into(), "list".into(), "tuple".into(), "Option".into()].contains(id) if ["virtual".into(), "Generic".into(), "tuple".into(), "Option".into()].contains(id) {
{ return Err(HashSet::from([format!(
return Err(HashSet::from([ "keywords cannot be class name (at {})",
format!("keywords cannot be class name (at {})", expr.location), expr.location
])) )]));
} }
let obj_id = resolver.get_identifier_def(*id)?; let obj_id = resolver.get_identifier_def(*id)?;
let type_vars = { let type_vars = {
@ -180,19 +170,16 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
vec![slice] vec![slice]
}; };
if type_vars.len() != params_ast.len() { if type_vars.len() != params_ast.len() {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"expect {} type parameters but got {} (at {})", "expect {} type parameters but got {} (at {})",
type_vars.len(), type_vars.len(),
params_ast.len(), params_ast.len(),
params_ast[0].location, params_ast[0].location,
), )]));
]))
} }
let result = params_ast let result = params_ast
.iter() .iter()
.enumerate() .map(|x| {
.map(|(idx, x)| {
parse_ast_to_type_annotation_kinds( parse_ast_to_type_annotation_kinds(
resolver, resolver,
top_level_defs, top_level_defs,
@ -203,7 +190,6 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
locked.insert(obj_id, type_vars.clone()); locked.insert(obj_id, type_vars.clone());
locked.clone() locked.clone()
}, },
Some(type_vars[idx]),
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
@ -218,7 +204,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
"application of type vars to generic class is not currently supported (at {})", "application of type vars to generic class is not currently supported (at {})",
params_ast[0].location params_ast[0].location
), ),
])) ]));
} }
}; };
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
@ -239,7 +225,6 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
slice.as_ref(), slice.as_ref(),
locked, locked,
None,
)?; )?;
if !matches!(def, TypeAnnotation::CustomClass { .. }) { if !matches!(def, TypeAnnotation::CustomClass { .. }) {
unreachable!("must be concretized custom class kind in the virtual") unreachable!("must be concretized custom class kind in the virtual")
@ -247,24 +232,6 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
Ok(TypeAnnotation::Virtual(def.into())) Ok(TypeAnnotation::Virtual(def.into()))
} }
// list
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into())
} =>
{
let def_ann = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
None,
)?;
Ok(TypeAnnotation::List(def_ann.into()))
}
// option // option
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
if { if {
@ -278,7 +245,6 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
slice.as_ref(), slice.as_ref(),
locked, locked,
None,
)?; )?;
let id = let id =
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() { if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() {
@ -312,54 +278,76 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
primitives, primitives,
e, e,
locked.clone(), locked.clone(),
None,
) )
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(TypeAnnotation::Tuple(type_annotations)) Ok(TypeAnnotation::Tuple(type_annotations))
} }
// Literal
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into())
} =>
{
let tup_elts = {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.as_slice()
} else {
std::slice::from_ref(slice.as_ref())
}
};
let type_annotations = tup_elts
.iter()
.map(|e| match &e.node {
ast::ExprKind::Constant { value, .. } => {
Ok(TypeAnnotation::Literal(vec![value.clone()]))
}
_ => parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
),
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flat_map(|type_ann| match type_ann {
TypeAnnotation::Literal(values) => values,
_ => unreachable!(),
})
.collect_vec();
if type_annotations.len() == 1 {
Ok(TypeAnnotation::Literal(type_annotations))
} else {
Err(HashSet::from([format!(
"multiple literal bounds are currently unsupported (at {})",
value.location
)]))
}
}
// custom class // custom class
ast::ExprKind::Subscript { value, slice, .. } => { ast::ExprKind::Subscript { value, slice, .. } => {
if let ast::ExprKind::Name { id, .. } = &value.node { if let ast::ExprKind::Name { id, .. } = &value.node {
class_name_handle(id, slice, unifier, locked) class_name_handle(id, slice, unifier, locked)
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!(
format!("unsupported expression type for class name (at {})", value.location) "unsupported expression type for class name (at {})",
])) value.location
)]))
} }
} }
ast::ExprKind::Constant { value, .. } => { ast::ExprKind::Constant { value, .. } => Ok(TypeAnnotation::Literal(vec![value.clone()])),
let type_var = type_var.expect("Expect type variable to be present");
let ntv_ty_enum = unifier.get_ty_immutable(type_var); _ => Err(HashSet::from([format!(
let TypeEnum::TVar { range: underlying_ty, .. } = ntv_ty_enum.as_ref() else { "unsupported expression for type annotation (at {})",
unreachable!()
};
let underlying_ty = underlying_ty[0];
let value = SymbolValue::from_constant(value, underlying_ty, primitives, unifier)
.map_err(|err| HashSet::from([err]))?;
if matches!(value, SymbolValue::Str(_) | SymbolValue::Tuple(_) | SymbolValue::OptionSome(_)) {
return Err(HashSet::from([
format!(
"expression {value} is not allowed for constant type annotation (at {})",
expr.location expr.location
), )])),
]))
}
Ok(TypeAnnotation::Constant {
ty: type_var,
value,
})
}
_ => Err(HashSet::from([
format!("unsupported expression for type annotation (at {})", expr.location),
])),
} }
} }
@ -370,7 +358,7 @@ pub fn get_type_from_type_annotation_kinds(
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
ann: &TypeAnnotation, ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>> subst_list: &mut Option<Vec<Type>>,
) -> Result<Type, HashSet<String>> { ) -> Result<Type, HashSet<String>> {
match ann { match ann {
TypeAnnotation::CustomClass { id: obj_id, params } => { TypeAnnotation::CustomClass { id: obj_id, params } => {
@ -381,34 +369,34 @@ pub fn get_type_from_type_annotation_kinds(
}; };
if type_vars.len() != params.len() { if type_vars.len() != params.len() {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"unexpected number of type parameters: expected {} but got {}", "unexpected number of type parameters: expected {} but got {}",
type_vars.len(), type_vars.len(),
params.len() params.len()
), )]));
]))
} }
let param_ty = params let param_ty = params
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
top_level_defs,
unifier,
x,
subst_list
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let subst = { let subst = {
// check for compatible range // check for compatible range
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check // TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result: HashMap<u32, Type> = HashMap::new(); let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) { for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() { match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false } => { TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = { let ok: bool = {
// create a temp type var and unify to check compatibility // create a temp type var and unify to check compatibility
p == *tvar || { p == *tvar || {
@ -417,14 +405,13 @@ pub fn get_type_from_type_annotation_kinds(
*name, *name,
*loc, *loc,
); );
unifier.unify(temp.0, p).is_ok() unifier.unify(temp.ty, p).is_ok()
} }
}; };
if ok { if ok {
result.insert(*id, p); result.insert(*id, p);
} else { } else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"cannot apply type {} to type variable with id {:?}", "cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify( unifier.internal_stringify(
p, p,
@ -433,8 +420,7 @@ pub fn get_type_from_type_annotation_kinds(
&mut None &mut None
), ),
*id *id
) )]));
]))
} }
} }
@ -443,24 +429,18 @@ pub fn get_type_from_type_annotation_kinds(
let ok: bool = { let ok: bool = {
// create a temp type var and unify to check compatibility // create a temp type var and unify to check compatibility
p == *tvar || { p == *tvar || {
let temp = unifier.get_fresh_const_generic_var( let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc);
ty, unifier.unify(temp.ty, p).is_ok()
*name,
*loc,
);
unifier.unify(temp.0, p).is_ok()
} }
}; };
if ok { if ok {
result.insert(*id, p); result.insert(*id, p);
} else { } else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"cannot apply type {} to type variable {}", "cannot apply type {} to type variable {}",
unifier.stringify(p), unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()), name.unwrap_or_else(|| format!("typevar{id}").into()),
), )]));
]))
} }
} }
@ -469,6 +449,7 @@ pub fn get_type_from_type_annotation_kinds(
} }
result result
}; };
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods let mut tobj_fields = methods
.iter() .iter()
.map(|(name, ty, _)| { .map(|(name, ty, _)| {
@ -495,14 +476,14 @@ pub fn get_type_from_type_annotation_kinds(
Ok(ty) Ok(ty)
} }
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Constant { ty, value, .. } => { TypeAnnotation::Literal(values) => {
let ty_enum = unifier.get_ty(*ty); let values = values
let TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } = &*ty_enum else { .iter()
unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name()); .map(SymbolValue::from_constant_inferred)
}; .collect::<Result<Vec<_>, _>>()
.map_err(|err| HashSet::from([err]))?;
let ty = ntv_underlying_ty[0]; let var = unifier.get_fresh_literal(values, None);
let var = unifier.get_fresh_constant(value.clone(), ty, *loc);
Ok(var) Ok(var)
} }
TypeAnnotation::Virtual(ty) => { TypeAnnotation::Virtual(ty) => {
@ -510,19 +491,10 @@ pub fn get_type_from_type_annotation_kinds(
top_level_defs, top_level_defs,
unifier, unifier,
ty.as_ref(), ty.as_ref(),
subst_list subst_list,
)?; )?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} }
TypeAnnotation::List(ty) => {
let ty = get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
ty.as_ref(),
subst_list
)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
}
TypeAnnotation::Tuple(tys) => { TypeAnnotation::Tuple(tys) => {
let tys = tys let tys = tys
.iter() .iter()
@ -563,7 +535,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<Ty
let mut result: Vec<TypeAnnotation> = Vec::new(); let mut result: Vec<TypeAnnotation> = Vec::new();
match ann { match ann {
TypeAnnotation::TypeVar(..) => result.push(ann.clone()), TypeAnnotation::TypeVar(..) => result.push(ann.clone()),
TypeAnnotation::Virtual(ann) | TypeAnnotation::List(ann) => { TypeAnnotation::Virtual(ann) => {
result.extend(get_type_var_contained_in_type_annotation(ann.as_ref())); result.extend(get_type_var_contained_in_type_annotation(ann.as_ref()));
} }
TypeAnnotation::CustomClass { params, .. } => { TypeAnnotation::CustomClass { params, .. } => {
@ -576,7 +548,7 @@ pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<Ty
result.extend(get_type_var_contained_in_type_annotation(a)); result.extend(get_type_var_contained_in_type_annotation(a));
} }
} }
TypeAnnotation::Primitive(..) | TypeAnnotation::Constant { .. } => {} TypeAnnotation::Primitive(..) | TypeAnnotation::Literal { .. } => {}
} }
result result
} }
@ -597,14 +569,14 @@ pub fn check_overload_type_annotation_compatible(
let ( let (
TypeEnum::TVar { id: a, fields: None, .. }, TypeEnum::TVar { id: a, fields: None, .. },
TypeEnum::TVar { id: b, fields: None, .. }, TypeEnum::TVar { id: b, fields: None, .. },
) = (a, b) else { ) = (a, b)
else {
unreachable!("must be type var") unreachable!("must be type var")
}; };
a == b a == b
} }
(TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b)) (TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b)) => {
| (TypeAnnotation::List(a), TypeAnnotation::List(b)) => {
check_overload_type_annotation_compatible(a.as_ref(), b.as_ref(), unifier) check_overload_type_annotation_compatible(a.as_ref(), b.as_ref(), unifier)
} }

View File

@ -1,16 +1,18 @@
use crate::typecheck::typedef::TypeEnum; use crate::toplevel::helper::PrimDef;
use super::type_inferencer::Inferencer; use super::type_inferencer::Inferencer;
use super::typedef::Type; use super::typedef::{Type, TypeEnum};
use nac3parser::ast::{self, Constant, Expr, ExprKind, Operator::{LShift, RShift}, Stmt, StmtKind, StrRef}; use nac3parser::ast::{
self, Constant, Expr, ExprKind,
Operator::{LShift, RShift},
Stmt, StmtKind, StrRef,
};
use std::{collections::HashSet, iter::once}; use std::{collections::HashSet, iter::once};
impl<'a> Inferencer<'a> { impl<'a> Inferencer<'a> {
fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> { fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> {
if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) { if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) {
Err(HashSet::from([ Err(HashSet::from([format!("Error at {}: cannot have value none", expr.location)]))
format!("Error at {}: cannot have value none", expr.location),
]))
} else { } else {
Ok(()) Ok(())
} }
@ -22,9 +24,9 @@ impl<'a> Inferencer<'a> {
defined_identifiers: &mut HashSet<StrRef>, defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> { ) -> Result<(), HashSet<String>> {
match &pattern.node { match &pattern.node {
ExprKind::Name { id, .. } if id == &"none".into() => Err(HashSet::from([ ExprKind::Name { id, .. } if id == &"none".into() => {
format!("cannot assign to a `none` (at {})", pattern.location), Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)]))
])), }
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) { if !defined_identifiers.contains(id) {
defined_identifiers.insert(*id); defined_identifiers.insert(*id);
@ -44,20 +46,17 @@ impl<'a> Inferencer<'a> {
self.should_have_value(value)?; self.should_have_value(value)?;
self.check_expr(slice, defined_identifiers)?; self.check_expr(slice, defined_identifiers)?;
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"Error at {}: cannot assign to tuple element", "Error at {}: cannot assign to tuple element",
value.location value.location
), )]));
]))
} }
Ok(()) Ok(())
} }
ExprKind::Constant { .. } => { ExprKind::Constant { .. } => Err(HashSet::from([format!(
Err(HashSet::from([ "cannot assign to a constant (at {})",
format!("cannot assign to a constant (at {})", pattern.location), pattern.location
])) )])),
}
_ => self.check_expr(pattern, defined_identifiers), _ => self.check_expr(pattern, defined_identifiers),
} }
} }
@ -69,14 +68,15 @@ impl<'a> Inferencer<'a> {
) -> Result<(), HashSet<String>> { ) -> Result<(), HashSet<String>> {
// there are some cases where the custom field is None // there are some cases where the custom field is None
if let Some(ty) = &expr.custom { if let Some(ty) = &expr.custom {
if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) { if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. })
return Err(HashSet::from([ && !ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::List.id())
format!( && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables)
{
return Err(HashSet::from([format!(
"expected concrete type at {} but got {}", "expected concrete type at {} but got {}",
expr.location, expr.location,
self.unifier.get_ty(*ty).get_type_name() self.unifier.get_ty(*ty).get_type_name()
) )]));
]))
} }
} }
match &expr.node { match &expr.node {
@ -96,12 +96,10 @@ impl<'a> Inferencer<'a> {
self.defined_identifiers.insert(*id); self.defined_identifiers.insert(*id);
} }
Err(e) => { Err(e) => {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"type error at identifier `{}` ({}) at {}", "type error at identifier `{}` ({}) at {}",
id, e, expr.location id, e, expr.location
) )]))
]))
} }
} }
} }
@ -127,17 +125,13 @@ impl<'a> Inferencer<'a> {
// Check whether a bitwise shift has a negative RHS constant value // Check whether a bitwise shift has a negative RHS constant value
if *op == LShift || *op == RShift { if *op == LShift || *op == RShift {
if let ExprKind::Constant { value, .. } = &right.node { if let ExprKind::Constant { value, .. } = &right.node {
let Constant::Int(rhs_val) = value else { let Constant::Int(rhs_val) = value else { unreachable!() };
unreachable!()
};
if *rhs_val < 0 { if *rhs_val < 0 {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"shift count is negative at {}", "shift count is negative at {}",
right.location right.location
), )]));
]))
} }
} }
} }
@ -208,6 +202,27 @@ impl<'a> Inferencer<'a> {
Ok(()) Ok(())
} }
/// Check that the return value is a non-`alloca` type, effectively only allowing primitive types.
///
/// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
/// is freed when the function returns.
fn check_return_value_ty(&mut self, ret_ty: Type) -> bool {
match &*self.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { .. } => [
self.primitives.int32,
self.primitives.int64,
self.primitives.uint32,
self.primitives.uint64,
self.primitives.float,
self.primitives.bool,
]
.iter()
.any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)),
_ => false,
}
}
// check statements for proper identifier def-use and return on all paths // check statements for proper identifier def-use and return on all paths
fn check_stmt( fn check_stmt(
&mut self, &mut self,
@ -302,6 +317,30 @@ impl<'a> Inferencer<'a> {
if let Some(value) = value { if let Some(value) = value {
self.check_expr(value, defined_identifiers)?; self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?; self.should_have_value(value)?;
// Check that the return value is a non-`alloca` type, effectively only allowing primitive types.
// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
// is freed when the function returns.
if let Some(ret_ty) = value.custom {
// Explicitly allow ellipsis as a return value, as the type of the ellipsis is contextually
// inferred and just generates an unconditional assertion
if matches!(
value.node,
ExprKind::Constant { value: Constant::Ellipsis, .. }
) {
return Ok(true);
}
if !self.check_return_value_ty(ret_ty) {
return Err(HashSet::from([
format!(
"return value of type {} must be a primitive or a tuple of primitives at {}",
self.unifier.stringify(ret_ty),
value.location,
),
]));
}
}
} }
Ok(true) Ok(true)
} }
@ -324,7 +363,7 @@ impl<'a> Inferencer<'a> {
let mut ret = false; let mut ret = false;
for stmt in block { for stmt in block {
if ret { if ret {
eprintln!("warning: dead code at {:?}\n", stmt.location); eprintln!("warning: dead code at {}\n", stmt.location);
} }
if self.check_stmt(stmt, defined_identifiers)? { if self.check_stmt(stmt, defined_identifiers)? {
ret = true; ret = true;

View File

@ -1,73 +1,150 @@
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use crate::typecheck::{ use crate::typecheck::{
type_inferencer::*, type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
}; };
use itertools::{iproduct, Itertools};
use nac3parser::ast::StrRef; use nac3parser::ast::StrRef;
use nac3parser::ast::{Cmpop, Operator, Unaryop}; use nac3parser::ast::{Cmpop, Operator, Unaryop};
use std::cmp::max;
use std::collections::HashMap; use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
use strum::IntoEnumIterator;
/// The variant of a binary operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinopVariant {
/// The normal variant.
/// For addition, it would be `+`.
Normal,
/// The "Augmented Assigning Operator" variant.
/// For addition, it would be `+=`.
AugAssign,
}
/// A binary operator with its variant.
#[derive(Debug, Clone, Copy)]
pub struct Binop {
/// The base [`Operator`] of this binary operator.
pub base: Operator,
/// The variant of this binary operator.
pub variant: BinopVariant,
}
impl Binop {
/// Make a [`Binop`] of the normal variant from an [`Operator`].
#[must_use] #[must_use]
pub fn binop_name(op: &Operator) -> &'static str { pub fn normal(base: Operator) -> Self {
match op { Binop { base, variant: BinopVariant::Normal }
Operator::Add => "__add__", }
Operator::Sub => "__sub__",
Operator::Div => "__truediv__", /// Make a [`Binop`] of the aug assign variant from an [`Operator`].
Operator::Mod => "__mod__", #[must_use]
Operator::Mult => "__mul__", pub fn aug_assign(base: Operator) -> Self {
Operator::Pow => "__pow__", Binop { base, variant: BinopVariant::AugAssign }
Operator::BitOr => "__or__",
Operator::BitXor => "__xor__",
Operator::BitAnd => "__and__",
Operator::LShift => "__lshift__",
Operator::RShift => "__rshift__",
Operator::FloorDiv => "__floordiv__",
Operator::MatMult => "__matmul__",
} }
} }
#[must_use] /// Details about an operator (unary, binary, etc...) in Python
pub fn binop_assign_name(op: &Operator) -> &'static str { #[derive(Debug, Clone, Copy)]
match op { pub struct OpInfo {
Operator::Add => "__iadd__", /// The method name of the binary operator.
Operator::Sub => "__isub__", /// For addition, this would be `__add__`, and `__iadd__` if
Operator::Div => "__itruediv__", /// it is the augmented assigning variant.
Operator::Mod => "__imod__", pub method_name: &'static str,
Operator::Mult => "__imul__", /// The symbol of the binary operator.
Operator::Pow => "__ipow__", /// For addition, this would be `+`, and `+=` if
Operator::BitOr => "__ior__", /// it is the augmented assigning variant.
Operator::BitXor => "__ixor__", pub symbol: &'static str,
Operator::BitAnd => "__iand__",
Operator::LShift => "__ilshift__",
Operator::RShift => "__irshift__",
Operator::FloorDiv => "__ifloordiv__",
Operator::MatMult => "__imatmul__",
}
} }
#[must_use] /// Helper macro to conveniently build an [`OpInfo`].
pub fn unaryop_name(op: &Unaryop) -> &'static str { ///
match op { /// Example usage: `make_info("add", "+")` generates `OpInfo { name: "__add__", symbol: "+" }`
Unaryop::UAdd => "__pos__", macro_rules! make_op_info {
Unaryop::USub => "__neg__", ($name:expr, $symbol:expr) => {
Unaryop::Not => "__not__", OpInfo { method_name: concat!("__", $name, "__"), symbol: $symbol }
Unaryop::Invert => "__inv__", };
}
} }
#[must_use] pub trait HasOpInfo {
pub fn comparison_name(op: &Cmpop) -> Option<&'static str> { fn op_info(&self) -> OpInfo;
}
fn try_get_cmpop_info(op: Cmpop) -> Option<OpInfo> {
match op { match op {
Cmpop::Lt => Some("__lt__"), Cmpop::Lt => Some(make_op_info!("lt", "<")),
Cmpop::LtE => Some("__le__"), Cmpop::LtE => Some(make_op_info!("le", "<=")),
Cmpop::Gt => Some("__gt__"), Cmpop::Gt => Some(make_op_info!("gt", ">")),
Cmpop::GtE => Some("__ge__"), Cmpop::GtE => Some(make_op_info!("ge", ">=")),
Cmpop::Eq => Some("__eq__"), Cmpop::Eq => Some(make_op_info!("eq", "==")),
Cmpop::NotEq => Some("__ne__"), Cmpop::NotEq => Some(make_op_info!("ne", "!=")),
_ => None, _ => None,
} }
} }
impl OpInfo {
#[must_use]
pub fn supports_cmpop(op: Cmpop) -> bool {
try_get_cmpop_info(op).is_some()
}
}
impl HasOpInfo for Cmpop {
fn op_info(&self) -> OpInfo {
try_get_cmpop_info(*self).expect("{self:?} is not supported")
}
}
impl HasOpInfo for Binop {
fn op_info(&self) -> OpInfo {
// Helper macro to generate both the normal variant [`OpInfo`] and the
// augmented assigning variant [`OpInfo`] for a binary operator conveniently.
macro_rules! info {
($name:literal, $symbol:literal) => {
(
make_op_info!($name, $symbol),
make_op_info!(concat!("i", $name), concat!($symbol, "=")),
)
};
}
let (normal_variant, aug_assign_variant) = match self.base {
Operator::Add => info!("add", "+"),
Operator::Sub => info!("sub", "-"),
Operator::Div => info!("truediv", "/"),
Operator::Mod => info!("mod", "%"),
Operator::Mult => info!("mul", "*"),
Operator::Pow => info!("pow", "**"),
Operator::BitOr => info!("or", "|"),
Operator::BitXor => info!("xor", "^"),
Operator::BitAnd => info!("and", "&"),
Operator::LShift => info!("lshift", "<<"),
Operator::RShift => info!("rshift", ">>"),
Operator::FloorDiv => info!("floordiv", "//"),
Operator::MatMult => info!("matmul", "@"),
};
match self.variant {
BinopVariant::Normal => normal_variant,
BinopVariant::AugAssign => aug_assign_variant,
}
}
}
impl HasOpInfo for Unaryop {
fn op_info(&self) -> OpInfo {
match self {
Unaryop::UAdd => make_op_info!("pos", "+"),
Unaryop::USub => make_op_info!("neg", "-"),
Unaryop::Not => make_op_info!("not", "not"), // i.e., `not False`, so the symbol is just `not`.
Unaryop::Invert => make_op_info!("inv", "~"),
}
}
}
pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F) pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F)
where where
F: FnOnce(&mut Unifier, &mut HashMap<StrRef, (Type, bool)>), F: FnOnce(&mut Unifier, &mut HashMap<StrRef, (Type, bool)>),
@ -90,40 +167,28 @@ pub fn impl_binop(
_store: &PrimitiveStore, _store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
ops: &[Operator], ops: &[Operator],
) { ) {
with_fields(unifier, ty, |unifier, fields| { with_fields(unifier, ty, |unifier, fields| {
let (other_ty, other_var_id) = if other_ty.len() == 1 { let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None) (other_ty[0], None)
} else { } else {
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(ty, Some(var_id)) (tvar.ty, Some(tvar.id))
}; };
let function_vars = if let Some(var_id) = other_var_id { let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<HashMap<_, _>>() vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
} else { } else {
HashMap::new() VarMap::new()
}; };
for op in ops { let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
fields.insert(binop_name(op).into(), {
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}],
})),
false,
)
});
fields.insert(binop_assign_name(op).into(), { for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) {
let op = Binop { base: *base_op, variant };
fields.insert(op.op_info().method_name.into(), {
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
@ -141,15 +206,17 @@ pub fn impl_binop(
}); });
} }
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[Unaryop]) { pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops: &[Unaryop]) {
with_fields(unifier, ty, |unifier, fields| { with_fields(unifier, ty, |unifier, fields| {
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for op in ops { for op in ops {
fields.insert( fields.insert(
unaryop_name(op).into(), op.op_info().method_name.into(),
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
vars: HashMap::new(), vars: VarMap::new(),
args: vec![], args: vec![],
})), })),
false, false,
@ -161,19 +228,35 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[Unaryo
pub fn impl_cmpop( pub fn impl_cmpop(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, _store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: Type, other_ty: &[Type],
ops: &[Cmpop], ops: &[Cmpop],
ret_ty: Option<Type>,
) { ) {
with_fields(unifier, ty, |unifier, fields| { with_fields(unifier, ty, |unifier, fields| {
let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None)
} else {
let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(tvar.ty, Some(tvar.id))
};
let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
} else {
VarMap::new()
};
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for op in ops { for op in ops {
fields.insert( fields.insert(
comparison_name(op).unwrap().into(), op.op_info().method_name.into(),
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.bool, ret: ret_ty,
vars: HashMap::new(), vars: function_vars.clone(),
args: vec![FuncArg { args: vec![FuncArg {
ty: other_ty, ty: other_ty,
default_value: None, default_value: None,
@ -193,7 +276,7 @@ pub fn impl_basic_arithmetic(
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
) { ) {
impl_binop( impl_binop(
unifier, unifier,
@ -211,7 +294,7 @@ pub fn impl_pow(
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]); impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]);
} }
@ -223,19 +306,32 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty
store, store,
ty, ty,
&[ty], &[ty],
ty, Some(ty),
&[Operator::BitAnd, Operator::BitOr, Operator::BitXor], &[Operator::BitAnd, Operator::BitOr, Operator::BitXor],
); );
} }
/// `LShift`, `RShift` /// `LShift`, `RShift`
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, &[store.int32, store.uint32], ty, &[Operator::LShift, Operator::RShift]); impl_binop(
unifier,
store,
ty,
&[store.int32, store.uint32],
Some(ty),
&[Operator::LShift, Operator::RShift],
);
} }
/// `Div` /// `Div`
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) { pub fn impl_div(
impl_binop(unifier, store, ty, other_ty, store.float, &[Operator::Div]); unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Div]);
} }
/// `FloorDiv` /// `FloorDiv`
@ -244,7 +340,7 @@ pub fn impl_floordiv(
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]); impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]);
} }
@ -255,40 +351,323 @@ pub fn impl_mod(
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]); impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]);
} }
/// [`Operator::MatMult`]
pub fn impl_matmul(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::MatMult]);
}
/// `UAdd`, `USub` /// `UAdd`, `USub`
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ty, &[Unaryop::UAdd, Unaryop::USub]); impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]);
} }
/// `Invert` /// `Invert`
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ty, &[Unaryop::Invert]); impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Invert]);
} }
/// `Not` /// `Not`
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_not(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, store.bool, &[Unaryop::Not]); impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Not]);
} }
/// `Lt`, `LtE`, `Gt`, `GtE` /// `Lt`, `LtE`, `Gt`, `GtE`
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) { pub fn impl_comparison(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_cmpop( impl_cmpop(
unifier, unifier,
store, store,
ty, ty,
other_ty, other_ty,
&[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE], &[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE],
ret_ty,
); );
} }
/// `Eq`, `NotEq` /// `Eq`, `NotEq`
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_eq(
impl_cmpop(unifier, store, ty, ty, &[Cmpop::Eq, Cmpop::NotEq]); unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
}
/// Returns the expected return type of binary operations with at least one `ndarray` operand.
pub fn typeof_ndarray_broadcast(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
left: Type,
right: Type,
) -> Result<Type, String> {
let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
assert!(is_left_ndarray || is_right_ndarray);
if is_left_ndarray && is_right_ndarray {
// Perform broadcasting on two ndarray operands.
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(),
};
let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(),
};
let res_ndims = left_ty_ndims
.into_iter()
.cartesian_product(right_ty_ndims)
.map(|(left, right)| {
let left_val = u64::try_from(left).unwrap();
let right_val = u64::try_from(right).unwrap();
max(left_val, right_val)
})
.unique()
.map(SymbolValue::U64)
.collect_vec();
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
} else {
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
Ok(ndarray_ty)
} else {
let (expected_ty, actual_ty) = if is_left_ndarray {
(ndarray_ty_dtype, scalar_ty)
} else {
(scalar_ty, ndarray_ty_dtype)
};
Err(format!(
"Expected right-hand side operand to be {}, got {}",
unifier.stringify(expected_ty),
unifier.stringify(actual_ty),
))
}
}
}
/// Returns the return type given a binary operator and its primitive operands.
pub fn typeof_binop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
op: Operator,
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let op = Binop { base: op, variant: BinopVariant::Normal };
let is_left_list = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
let is_right_list = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
Ok(Some(match op.base {
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
if is_left_list || is_right_list {
if ![Operator::Add, Operator::Mult].contains(&op.base) {
return Err(format!(
"Binary operator {} not supported for list",
op.op_info().symbol
));
}
if is_left_list {
lhs
} else {
rhs
}
} else if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) {
lhs
} else {
return Ok(None);
}
}
Operator::MatMult => {
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
}
_ => unreachable!(),
};
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
}
_ => unreachable!(),
};
match (lhs_ndims, rhs_ndims) {
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
(lhs, rhs) if lhs == 0 || rhs == 0 => {
return Err(format!(
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
u8::from(rhs == 0)
))
}
(lhs, rhs) => {
return Err(format!(
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
))
}
}
}
Operator::Div => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) {
primitives.float
} else {
return Ok(None);
}
}
Operator::Pow => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if [
primitives.int32,
primitives.int64,
primitives.uint32,
primitives.uint64,
primitives.float,
]
.into_iter()
.any(|ty| unifier.unioned(lhs, ty))
{
lhs
} else {
return Ok(None);
}
}
Operator::LShift | Operator::RShift => lhs,
Operator::BitOr | Operator::BitXor | Operator::BitAnd => {
if unifier.unioned(lhs, rhs) {
lhs
} else {
return Ok(None);
}
}
}))
}
pub fn typeof_unaryop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
op: Unaryop,
operand: Type,
) -> Result<Option<Type>, String> {
let operand_obj_id = operand.obj_id(unifier);
if op == Unaryop::Not
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap())
{
return Err(
"The truth value of an array with more than one element is ambiguous".to_string()
);
}
Ok(match op {
Unaryop::Not => match operand_obj_id {
Some(v) if v == PrimDef::NDArray.id() => Some(operand),
Some(_) => Some(primitives.bool),
_ => None,
},
Unaryop::Invert => {
if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) {
Some(operand)
} else {
None
}
}
Unaryop::UAdd | Unaryop::USub => {
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
return Err(if op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
} else {
"The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string()
});
}
Some(operand)
} else if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) {
Some(operand)
} else {
None
}
}
})
}
/// Returns the return type given a comparison operator and its primitive operands.
pub fn typeof_cmpop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
_op: Cmpop,
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
} else if unifier.unioned(lhs, rhs) {
primitives.bool
} else {
return Ok(None);
}))
} }
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
@ -299,39 +678,77 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
bool: bool_t, bool: bool_t,
uint32: uint32_t, uint32: uint32_t,
uint64: uint64_t, uint64: uint64_t,
list: list_t,
ndarray: ndarray_t,
.. ..
} = *store; } = *store;
let size_t = store.usize();
/* int ======== */ /* int ======== */
for t in [int32_t, int64_t, uint32_t, uint64_t] { for t in [int32_t, int64_t, uint32_t, uint64_t] {
impl_basic_arithmetic(unifier, store, t, &[t], t); let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None);
impl_pow(unifier, store, t, &[t], t); impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None);
impl_pow(unifier, store, t, &[t, ndarray_int_t], None);
impl_bitwise_arithmetic(unifier, store, t); impl_bitwise_arithmetic(unifier, store, t);
impl_bitwise_shift(unifier, store, t); impl_bitwise_shift(unifier, store, t);
impl_div(unifier, store, t, &[t]); impl_div(unifier, store, t, &[t, ndarray_int_t], None);
impl_floordiv(unifier, store, t, &[t], t); impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
impl_mod(unifier, store, t, &[t], t); impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
impl_invert(unifier, store, t); impl_invert(unifier, store, t, Some(t));
impl_not(unifier, store, t); impl_not(unifier, store, t, Some(bool_t));
impl_comparison(unifier, store, t, t); impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
impl_eq(unifier, store, t); impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
} }
for t in [int32_t, int64_t] { for t in [int32_t, int64_t] {
impl_sign(unifier, store, t); impl_sign(unifier, store, t, Some(t));
} }
/* float ======== */ /* float ======== */
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t); let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t); let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None);
impl_div(unifier, store, float_t, &[float_t]); impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_floordiv(unifier, store, float_t, &[float_t], float_t); impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None);
impl_mod(unifier, store, float_t, &[float_t], float_t); impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_sign(unifier, store, float_t); impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_not(unifier, store, float_t); impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_comparison(unifier, store, float_t, float_t); impl_sign(unifier, store, float_t, Some(float_t));
impl_eq(unifier, store, float_t); impl_not(unifier, store, float_t, Some(bool_t));
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
/* bool ======== */ /* bool ======== */
impl_not(unifier, store, bool_t); let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
impl_eq(unifier, store, bool_t); impl_invert(unifier, store, bool_t, Some(int32_t));
impl_not(unifier, store, bool_t, Some(bool_t));
impl_sign(unifier, store, bool_t, Some(int32_t));
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* list ======== */
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);
impl_cmpop(unifier, store, list_t, &[list_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
/* ndarray ===== */
let ndarray_usized_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t =
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty));
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
impl_basic_arithmetic(
unifier,
store,
ndarray_t,
&[ndarray_unsized_t, ndarray_unsized_dtype_t],
None,
);
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
} }

View File

@ -1,24 +1,46 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Display; use std::fmt::Display;
use crate::typecheck::typedef::TypeEnum; use crate::typecheck::{magic_methods::HasOpInfo, typedef::TypeEnum};
use super::typedef::{RecordKey, Type, Unifier}; use super::{
use nac3parser::ast::{Location, StrRef}; magic_methods::Binop,
typedef::{RecordKey, Type, Unifier},
};
use itertools::Itertools;
use nac3parser::ast::{Cmpop, Location, StrRef};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum TypeErrorKind { pub enum TypeErrorKind {
TooManyArguments { GotMultipleValues {
expected: usize, name: StrRef,
got: usize, },
TooManyArguments {
expected_min_count: usize,
expected_max_count: usize,
got_count: usize,
},
MissingArgs {
missing_arg_names: Vec<StrRef>,
}, },
MissingArgs(String),
UnknownArgName(StrRef), UnknownArgName(StrRef),
IncorrectArgType { IncorrectArgType {
name: StrRef, name: StrRef,
expected: Type, expected: Type,
got: Type, got: Type,
}, },
UnsupportedBinaryOpTypes {
operator: Binop,
lhs_type: Type,
rhs_type: Type,
expected_rhs_type: Type,
},
UnsupportedComparsionOpTypes {
operator: Cmpop,
lhs_type: Type,
rhs_type: Type,
expected_rhs_type: Type,
},
FieldUnificationError { FieldUnificationError {
field: RecordKey, field: RecordKey,
types: (Type, Type), types: (Type, Type),
@ -34,6 +56,7 @@ pub enum TypeErrorKind {
}, },
RequiresTypeAnn, RequiresTypeAnn,
PolymorphicFunctionPointer, PolymorphicFunctionPointer,
NoSuchAttribute(RecordKey, Type),
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -77,22 +100,49 @@ impl<'a> Display for DisplayTypeError<'a> {
use TypeErrorKind::*; use TypeErrorKind::*;
let mut notes = Some(HashMap::new()); let mut notes = Some(HashMap::new());
match &self.err.kind { match &self.err.kind {
TooManyArguments { expected, got } => { GotMultipleValues { name } => {
write!(f, "Too many arguments. Expected {expected} but got {got}") write!(f, "For multiple values for parameter {name}")
} }
MissingArgs(args) => { TooManyArguments { expected_min_count, expected_max_count, got_count } => {
debug_assert!(expected_min_count <= expected_max_count);
if expected_min_count == expected_max_count {
let expected_count = expected_min_count; // or expected_max_count
write!(f, "Too many arguments. Expected {expected_count} but got {got_count}")
} else {
write!(f, "Too many arguments. Expected {expected_min_count} to {expected_max_count} arguments but got {got_count}")
}
}
MissingArgs { missing_arg_names } => {
let args = missing_arg_names.iter().join(", ");
write!(f, "Missing arguments: {args}") write!(f, "Missing arguments: {args}")
} }
UnsupportedBinaryOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
let op_symbol = operator.op_info().symbol;
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
let expected_rhs_type_str =
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
write!(f, "Unsupported operand type(s) for {op_symbol}: '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
}
UnsupportedComparsionOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
let op_symbol = operator.op_info().symbol;
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
let expected_rhs_type_str =
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
write!(f, "'{op_symbol}' not supported between instances of '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
}
UnknownArgName(name) => { UnknownArgName(name) => {
write!(f, "Unknown argument name: {name}") write!(f, "Unknown argument name: {name}")
} }
IncorrectArgType { name, expected, got } => { IncorrectArgType { name, expected, got } => {
let expected = self.unifier.stringify_with_notes(*expected, &mut notes); let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
let got = self.unifier.stringify_with_notes(*got, &mut notes); let got = self.unifier.stringify_with_notes(*got, &mut notes);
write!( write!(f, "Incorrect argument type for parameter {name}. Expected {expected}, but got {got}")
f,
"Incorrect argument type for {name}. Expected {expected}, but got {got}"
)
} }
FieldUnificationError { field, types, loc } => { FieldUnificationError { field, types, loc } => {
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes); let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);
@ -159,6 +209,10 @@ impl<'a> Display for DisplayTypeError<'a> {
let t = self.unifier.stringify_with_notes(*t, &mut notes); let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` field/method does not exist") write!(f, "`{t}::{name}` field/method does not exist")
} }
NoSuchAttribute(name, t) => {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` is not a class attribute")
}
TupleIndexOutOfBounds { index, len } => { TupleIndexOutOfBounds { index, len } => {
write!( write!(
f, f,

File diff suppressed because it is too large Load Diff

View File

@ -3,12 +3,14 @@ use super::*;
use crate::{ use crate::{
codegen::CodeGenContext, codegen::CodeGenContext,
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef}, toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
}; };
use indexmap::IndexMap;
use indoc::indoc; use indoc::indoc;
use std::iter::zip; use nac3parser::ast::FileName;
use nac3parser::parser::parse_program; use nac3parser::parser::parse_program;
use parking_lot::RwLock; use parking_lot::RwLock;
use std::iter::zip;
use test_case::test_case; use test_case::test_case;
struct Resolver { struct Resolver {
@ -32,19 +34,21 @@ impl SymbolResolver for Resolver {
_: &PrimitiveStore, _: &PrimitiveStore,
str: StrRef, str: StrRef,
) -> Result<Type, String> { ) -> Result<Type, String> {
self.id_to_type.get(&str).cloned().ok_or_else(|| format!("cannot find symbol `{}`", str)) self.id_to_type.get(&str).copied().ok_or_else(|| format!("cannot find symbol `{str}`"))
} }
fn get_symbol_value<'ctx, 'a>( fn get_symbol_value<'ctx>(
&self, &self,
_: StrRef, _: StrRef,
_: &mut CodeGenContext<'ctx, 'a>, _: &mut CodeGenContext<'ctx, '_>,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
unimplemented!() unimplemented!()
} }
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> { fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.id_to_def.get(&id).cloned() self.id_to_def
.get(&id)
.copied()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()])) .ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
} }
@ -73,67 +77,81 @@ impl TestEnvironment {
let mut unifier = Unifier::new(); let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0), obj_id: PrimDef::Int32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
with_fields(&mut unifier, int32, |unifier, fields| { with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
ret: int32, ret: int32,
vars: HashMap::new(), vars: VarMap::new(),
})); }));
fields.insert("__add__".into(), (add_ty, false)); fields.insert("__add__".into(), (add_ty, false));
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1), obj_id: PrimDef::Int64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2), obj_id: PrimDef::Float.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3), obj_id: PrimDef::Bool.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4), obj_id: PrimDef::None.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let range = unifier.add_ty(TypeEnum::TObj { let range = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: PrimDef::Range.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let str = unifier.add_ty(TypeEnum::TObj { let str = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(6), obj_id: PrimDef::Str.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let exception = unifier.add_ty(TypeEnum::TObj { let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(7), obj_id: PrimDef::Exception.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let uint32 = unifier.add_ty(TypeEnum::TObj { let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(8), obj_id: PrimDef::UInt32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let uint64 = unifier.add_ty(TypeEnum::TObj { let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(9), obj_id: PrimDef::UInt64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(10), obj_id: PrimDef::Option.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
});
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}); });
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
@ -147,10 +165,14 @@ impl TestEnvironment {
uint32, uint32,
uint64, uint64,
option, option,
list,
ndarray,
size_t: 64,
}; };
unifier.put_primitive_store(&primitives);
set_primitives_magic_methods(&primitives, &mut unifier); set_primitives_magic_methods(&primitives, &mut unifier);
let id_to_name = [ let id_to_name: HashMap<_, _> = [
(0, "int32".into()), (0, "int32".into()),
(1, "int64".into()), (1, "int64".into()),
(2, "float".into()), (2, "float".into()),
@ -160,23 +182,21 @@ impl TestEnvironment {
(6, "str".into()), (6, "str".into()),
(7, "exception".into()), (7, "exception".into()),
] ]
.iter() .into();
.cloned()
.collect();
let mut identifier_mapping = HashMap::new(); let mut identifier_mapping = HashMap::new();
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
let resolver = Arc::new(Resolver { let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(), id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(), id_to_def: HashMap::default(),
class_names: Default::default(), class_names: HashMap::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>; }) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment { TestEnvironment {
top_level: TopLevelContext { top_level: TopLevelContext {
definitions: Default::default(), definitions: Arc::default(),
unifiers: Default::default(), unifiers: Arc::default(),
personality_symbol: None, personality_symbol: None,
}, },
unifier, unifier,
@ -198,70 +218,95 @@ impl TestEnvironment {
let mut identifier_mapping = HashMap::new(); let mut identifier_mapping = HashMap::new();
let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new(); let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0), obj_id: PrimDef::Int32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
with_fields(&mut unifier, int32, |unifier, fields| { with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }],
ret: int32, ret: int32,
vars: HashMap::new(), vars: VarMap::new(),
})); }));
fields.insert("__add__".into(), (add_ty, false)); fields.insert("__add__".into(), (add_ty, false));
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1), obj_id: PrimDef::Int64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2), obj_id: PrimDef::Float.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3), obj_id: PrimDef::Bool.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4), obj_id: PrimDef::None.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let range = unifier.add_ty(TypeEnum::TObj { let range = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: PrimDef::Range.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let str = unifier.add_ty(TypeEnum::TObj { let str = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(6), obj_id: PrimDef::Str.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let exception = unifier.add_ty(TypeEnum::TObj { let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(7), obj_id: PrimDef::Exception.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let uint32 = unifier.add_ty(TypeEnum::TObj { let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(8), obj_id: PrimDef::UInt32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let uint64 = unifier.add_ty(TypeEnum::TObj { let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(9), obj_id: PrimDef::UInt64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(10), obj_id: PrimDef::Option.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
});
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(),
params: VarMap::new(),
}); });
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"] for (i, name) in [
"int32",
"int64",
"float",
"bool",
"none",
"range",
"str",
"Exception",
"uint32",
"uint64",
"Option",
"list",
"ndarray",
]
.iter() .iter()
.enumerate() .enumerate()
{ {
@ -269,10 +314,11 @@ impl TestEnvironment {
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: (*name).into(), name: (*name).into(),
object_id: DefinitionId(i), object_id: DefinitionId(i),
type_vars: Default::default(), type_vars: Vec::default(),
fields: Default::default(), fields: Vec::default(),
methods: Default::default(), attributes: Vec::default(),
ancestors: Default::default(), methods: Vec::default(),
ancestors: Vec::default(),
resolver: None, resolver: None,
constructor: None, constructor: None,
loc: None, loc: None,
@ -280,7 +326,7 @@ impl TestEnvironment {
.into(), .into(),
); );
} }
let defs = 7; let defs = 12;
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
@ -294,23 +340,29 @@ impl TestEnvironment {
uint32, uint32,
uint64, uint64,
option, option,
list,
ndarray,
size_t: 64,
}; };
let (v0, id) = unifier.get_dummy_var(); unifier.put_primitive_store(&primitives);
let tvar = unifier.get_dummy_var();
let foo_ty = unifier.add_ty(TypeEnum::TObj { let foo_ty = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 1), obj_id: DefinitionId(defs + 1),
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(), fields: [("a".into(), (tvar.ty, true))].into(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>(), params: into_var_map([tvar]),
}); });
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: "Foo".into(), name: "Foo".into(),
object_id: DefinitionId(defs + 1), object_id: DefinitionId(defs + 1),
type_vars: vec![v0], type_vars: vec![tvar.ty],
fields: [("a".into(), v0, true)].into(), fields: [("a".into(), tvar.ty, true)].into(),
methods: Default::default(), attributes: Vec::default(),
ancestors: Default::default(), methods: Vec::default(),
ancestors: Vec::default(),
resolver: None, resolver: None,
constructor: None, constructor: None,
loc: None, loc: None,
@ -323,31 +375,29 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: foo_ty, ret: foo_ty,
vars: [(id, v0)].iter().cloned().collect(), vars: into_var_map([tvar]),
})), })),
); );
let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature { let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: int32, ret: int32,
vars: Default::default(), vars: IndexMap::default(),
})); }));
let bar = unifier.add_ty(TypeEnum::TObj { let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 2), obj_id: DefinitionId(defs + 2),
fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))] fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))].into(),
.iter() params: IndexMap::default(),
.cloned()
.collect::<HashMap<_, _>>(),
params: Default::default(),
}); });
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: "Bar".into(), name: "Bar".into(),
object_id: DefinitionId(defs + 2), object_id: DefinitionId(defs + 2),
type_vars: Default::default(), type_vars: Vec::default(),
fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(), fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(),
methods: Default::default(), attributes: Vec::default(),
ancestors: Default::default(), methods: Vec::default(),
ancestors: Vec::default(),
resolver: None, resolver: None,
constructor: None, constructor: None,
loc: None, loc: None,
@ -359,26 +409,24 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: bar, ret: bar,
vars: Default::default(), vars: IndexMap::default(),
})), })),
); );
let bar2 = unifier.add_ty(TypeEnum::TObj { let bar2 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 3), obj_id: DefinitionId(defs + 3),
fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))] fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))].into(),
.iter() params: IndexMap::default(),
.cloned()
.collect::<HashMap<_, _>>(),
params: Default::default(),
}); });
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: "Bar2".into(), name: "Bar2".into(),
object_id: DefinitionId(defs + 3), object_id: DefinitionId(defs + 3),
type_vars: Default::default(), type_vars: Vec::default(),
fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(), fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(),
methods: Default::default(), attributes: Vec::default(),
ancestors: Default::default(), methods: Vec::default(),
ancestors: Vec::default(),
resolver: None, resolver: None,
constructor: None, constructor: None,
loc: None, loc: None,
@ -390,10 +438,10 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: bar2, ret: bar2,
vars: Default::default(), vars: IndexMap::default(),
})), })),
); );
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect(); let class_names: HashMap<_, _> = [("Bar".into(), bar), ("Bar2".into(), bar2)].into();
let id_to_name = [ let id_to_name = [
"int32".into(), "int32".into(),
@ -404,18 +452,22 @@ impl TestEnvironment {
"range".into(), "range".into(),
"str".into(), "str".into(),
"exception".into(), "exception".into(),
"uint32".into(),
"uint64".into(),
"option".into(),
"list".into(),
"ndarray".into(),
"Foo".into(), "Foo".into(),
"Bar".into(), "Bar".into(),
"Bar2".into(), "Bar2".into(),
] ]
.iter() .into_iter()
.enumerate() .enumerate()
.map(|(a, b)| (a, *b))
.collect(); .collect();
let top_level = TopLevelContext { let top_level = TopLevelContext {
definitions: Arc::new(top_level_defs.into()), definitions: Arc::new(top_level_defs.into()),
unifiers: Default::default(), unifiers: Arc::default(),
personality_symbol: None, personality_symbol: None,
}; };
@ -426,9 +478,7 @@ impl TestEnvironment {
("Bar".into(), DefinitionId(defs + 2)), ("Bar".into(), DefinitionId(defs + 2)),
("Bar2".into(), DefinitionId(defs + 3)), ("Bar2".into(), DefinitionId(defs + 3)),
] ]
.iter() .into(),
.cloned()
.collect(),
class_names, class_names,
}) as Arc<dyn SymbolResolver + Send + Sync>; }) as Arc<dyn SymbolResolver + Send + Sync>;
@ -453,11 +503,11 @@ impl TestEnvironment {
top_level: &self.top_level, top_level: &self.top_level,
function_data: &mut self.function_data, function_data: &mut self.function_data,
unifier: &mut self.unifier, unifier: &mut self.unifier,
variable_mapping: Default::default(), variable_mapping: HashMap::default(),
primitives: &mut self.primitives, primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks, virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls, calls: &mut self.calls,
defined_identifiers: Default::default(), defined_identifiers: HashSet::default(),
in_handler: false, in_handler: false,
} }
} }
@ -469,7 +519,7 @@ impl TestEnvironment {
c = 1.234 c = 1.234
d = True d = True
"}, "},
[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect(), &[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].into(),
&[] &[]
; "primitives test")] ; "primitives test")]
#[test_case(indoc! {" #[test_case(indoc! {"
@ -478,7 +528,7 @@ impl TestEnvironment {
c = 1.234 c = 1.234
d = b(c) d = b(c)
"}, "},
[("a", "fn[[x:float, y:float], float]"), ("b", "fn[[x:float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(), &[("a", "fn[[x:float, y:float], float]"), ("b", "fn[[x:float], float]"), ("c", "float"), ("d", "float")].into(),
&[] &[]
; "lambda test")] ; "lambda test")]
#[test_case(indoc! {" #[test_case(indoc! {"
@ -487,7 +537,7 @@ impl TestEnvironment {
a = b a = b
c = b(1) c = b(1)
"}, "},
[("a", "fn[[x:int32], int32]"), ("b", "fn[[x:int32], int32]"), ("c", "int32")].iter().cloned().collect(), &[("a", "fn[[x:int32], int32]"), ("b", "fn[[x:int32], int32]"), ("c", "int32")].into(),
&[] &[]
; "lambda test 2")] ; "lambda test 2")]
#[test_case(indoc! {" #[test_case(indoc! {"
@ -503,15 +553,15 @@ impl TestEnvironment {
b(123) b(123)
"}, "},
[("a", "fn[[x:bool], bool]"), ("b", "fn[[x:int32], int32]"), ("c", "bool"), &[("a", "fn[[x:bool], bool]"), ("b", "fn[[x:int32], int32]"), ("c", "bool"),
("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(), ("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].into(),
&[] &[]
; "obj test")] ; "obj test")]
#[test_case(indoc! {" #[test_case(indoc! {"
a = [1, 2, 3] a = [1, 2, 3]
b = [x + x for x in a] b = [x + x for x in a]
"}, "},
[("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(), &[("a", "list[int32]"), ("b", "list[int32]")].into(),
&[] &[]
; "listcomp test")] ; "listcomp test")]
#[test_case(indoc! {" #[test_case(indoc! {"
@ -519,25 +569,25 @@ impl TestEnvironment {
b = a.b() b = a.b()
a = virtual(Bar2()) a = virtual(Bar2())
"}, "},
[("a", "virtual[Bar]"), ("b", "int32")].iter().cloned().collect(), &[("a", "virtual[Bar]"), ("b", "int32")].into(),
&[("Bar", "Bar"), ("Bar2", "Bar")] &[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual test")] ; "virtual test")]
#[test_case(indoc! {" #[test_case(indoc! {"
a = [virtual(Bar(), Bar), virtual(Bar2())] a = [virtual(Bar(), Bar), virtual(Bar2())]
b = [x.b() for x in a] b = [x.b() for x in a]
"}, "},
[("a", "list[virtual[Bar]]"), ("b", "list[int32]")].iter().cloned().collect(), &[("a", "list[virtual[Bar]]"), ("b", "list[int32]")].into(),
&[("Bar", "Bar"), ("Bar2", "Bar")] &[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual list test")] ; "virtual list test")]
fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &str)]) { fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &str)]) {
println!("source:\n{}", source); println!("source:\n{source}");
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name); let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect(); let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().copied().collect();
defined_identifiers.insert("virtual".into()); defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone(); inferencer.defined_identifiers.clone_from(&defined_identifiers);
let statements = parse_program(source, Default::default()).unwrap(); let statements = parse_program(source, FileName::default()).unwrap();
let statements = statements let statements = statements
.into_iter() .into_iter()
.map(|v| inferencer.fold_stmt(v)) .map(|v| inferencer.fold_stmt(v))
@ -546,37 +596,37 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
inferencer.check_block(&statements, &mut defined_identifiers).unwrap(); inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in inferencer.variable_mapping.iter() { for (k, v) in &inferencer.variable_mapping {
let name = inferencer.unifier.internal_stringify( let name = inferencer.unifier.internal_stringify(
*v, *v,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{v}"),
&mut None, &mut None,
); );
println!("{}: {}", k, name); println!("{k}: {name}");
} }
for (k, v) in mapping.iter() { for (k, v) in mapping {
let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap(); let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap();
let name = inferencer.unifier.internal_stringify( let name = inferencer.unifier.internal_stringify(
*ty, *ty,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{v}"),
&mut None, &mut None,
); );
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); assert_eq!(format!("{k}: {v}"), format!("{k}: {name}"));
} }
assert_eq!(inferencer.virtual_checks.len(), virtuals.len()); assert_eq!(inferencer.virtual_checks.len(), virtuals.len());
for ((a, b, _), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) { for ((a, b, _), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) {
let a = inferencer.unifier.internal_stringify( let a = inferencer.unifier.internal_stringify(
*a, *a,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{v}"),
&mut None, &mut None,
); );
let b = inferencer.unifier.internal_stringify( let b = inferencer.unifier.internal_stringify(
*b, *b,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{v}"),
&mut None, &mut None,
); );
@ -595,14 +645,14 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
g = a // b g = a // b
h = a % b h = a % b
"}, "},
[("a", "int32"), &[("a", "int32"),
("b", "int32"), ("b", "int32"),
("c", "int32"), ("c", "int32"),
("d", "int32"), ("d", "int32"),
("e", "int32"), ("e", "int32"),
("f", "float"), ("f", "float"),
("g", "int32"), ("g", "int32"),
("h", "int32")].iter().cloned().collect() ("h", "int32")].into()
; "int32")] ; "int32")]
#[test_case( #[test_case(
indoc! {" indoc! {"
@ -618,7 +668,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
ii = 3 ii = 3
j = a ** b j = a ** b
"}, "},
[("a", "float"), &[("a", "float"),
("b", "float"), ("b", "float"),
("c", "float"), ("c", "float"),
("d", "float"), ("d", "float"),
@ -628,7 +678,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
("h", "float"), ("h", "float"),
("i", "float"), ("i", "float"),
("ii", "int32"), ("ii", "int32"),
("j", "float")].iter().cloned().collect() ("j", "float")].into()
; "float" ; "float"
)] )]
#[test_case( #[test_case(
@ -646,7 +696,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
k = a < b k = a < b
l = a != b l = a != b
"}, "},
[("a", "int64"), &[("a", "int64"),
("b", "int64"), ("b", "int64"),
("c", "int64"), ("c", "int64"),
("d", "int64"), ("d", "int64"),
@ -657,7 +707,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
("i", "bool"), ("i", "bool"),
("j", "bool"), ("j", "bool"),
("k", "bool"), ("k", "bool"),
("l", "bool")].iter().cloned().collect() ("l", "bool")].into()
; "int64" ; "int64"
)] )]
#[test_case( #[test_case(
@ -668,22 +718,22 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
d = not a d = not a
e = a != b e = a != b
"}, "},
[("a", "bool"), &[("a", "bool"),
("b", "bool"), ("b", "bool"),
("c", "bool"), ("c", "bool"),
("d", "bool"), ("d", "bool"),
("e", "bool")].iter().cloned().collect() ("e", "bool")].into()
; "boolean" ; "boolean"
)] )]
fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) { fn test_primitive_magic_methods(source: &str, mapping: &HashMap<&str, &str>) {
println!("source:\n{}", source); println!("source:\n{source}");
let mut env = TestEnvironment::basic_test_env(); let mut env = TestEnvironment::basic_test_env();
let id_to_name = std::mem::take(&mut env.id_to_name); let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect(); let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().copied().collect();
defined_identifiers.insert("virtual".into()); defined_identifiers.insert("virtual".into());
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone(); inferencer.defined_identifiers.clone_from(&defined_identifiers);
let statements = parse_program(source, Default::default()).unwrap(); let statements = parse_program(source, FileName::default()).unwrap();
let statements = statements let statements = statements
.into_iter() .into_iter()
.map(|v| inferencer.fold_stmt(v)) .map(|v| inferencer.fold_stmt(v))
@ -692,23 +742,23 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
inferencer.check_block(&statements, &mut defined_identifiers).unwrap(); inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in inferencer.variable_mapping.iter() { for (k, v) in &inferencer.variable_mapping {
let name = inferencer.unifier.internal_stringify( let name = inferencer.unifier.internal_stringify(
*v, *v,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{v}"),
&mut None, &mut None,
); );
println!("{}: {}", k, name); println!("{k}: {name}");
} }
for (k, v) in mapping.iter() { for (k, v) in mapping {
let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap(); let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap();
let name = inferencer.unifier.internal_stringify( let name = inferencer.unifier.internal_stringify(
*ty, *ty,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{v}"),
&mut None, &mut None,
); );
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); assert_eq!(format!("{k}: {v}"), format!("{k}: {name}"));
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -32,28 +32,25 @@ impl Unifier {
ty1.len() == ty2.len() ty1.len() == ty2.len()
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2)) && ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
} }
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 }) (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => self.eq(*ty1, *ty2),
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
self.eq(*ty1, *ty2)
}
( (
TypeEnum::TObj { obj_id: id1, params: params1, .. }, TypeEnum::TObj { obj_id: id1, params: params1, .. },
TypeEnum::TObj { obj_id: id2, params: params2, .. }, TypeEnum::TObj { obj_id: id2, params: params2, .. },
) => id1 == id2 && self.map_eq(params1, params2), ) => id1 == id2 && self.map_eq(params1, params2),
// TCall and TFunc are not yet implemented // TLiteral, TCall and TFunc are not yet implemented
_ => false, _ => false,
} }
} }
fn map_eq<K>(&mut self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool fn map_eq<K>(&mut self, map1: &IndexMapping<K>, map2: &IndexMapping<K>) -> bool
where where
K: std::hash::Hash + Eq + Clone, K: std::hash::Hash + Eq + Clone,
{ {
if map1.len() != map2.len() { if map1.len() != map2.len() {
return false; return false;
} }
for (k, v) in map1.iter() { for (k, v) in map1 {
if !map2.get(k).map(|v1| self.eq(*v, *v1)).unwrap_or(false) { if !map2.get(k).is_some_and(|v1| self.eq(*v, *v1)) {
return false; return false;
} }
} }
@ -67,8 +64,8 @@ impl Unifier {
if map1.len() != map2.len() { if map1.len() != map2.len() {
return false; return false;
} }
for (k, v) in map1.iter() { for (k, v) in map1 {
if !map2.get(k).map(|v1| self.eq(v.ty, v1.ty)).unwrap_or(false) { if !map2.get(k).is_some_and(|v1| self.eq(v.ty, v1.ty)) {
return false; return false;
} }
} }
@ -91,7 +88,7 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0), obj_id: DefinitionId(0),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}), }),
); );
type_mapping.insert( type_mapping.insert(
@ -99,7 +96,7 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1), obj_id: DefinitionId(1),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}), }),
); );
type_mapping.insert( type_mapping.insert(
@ -107,16 +104,25 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2), obj_id: DefinitionId(2),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}), }),
); );
let (v0, id) = unifier.get_dummy_var(); let tvar = unifier.get_dummy_var();
type_mapping.insert( type_mapping.insert(
"Foo".into(), "Foo".into(),
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3), obj_id: DefinitionId(3),
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(), fields: [("a".into(), (tvar.ty, true))].into(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>(), params: into_var_map([tvar]),
}),
);
let tvar = unifier.get_dummy_var();
type_mapping.insert(
"list".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([tvar]),
}), }),
); );
@ -129,14 +135,40 @@ impl TestEnvironment {
result.0 result.0
} }
fn internal_parse<'a, 'b>( fn internal_parse<'b>(&mut self, typ: &'b str, mapping: &Mapping<String>) -> (Type, &'b str) {
&'a mut self,
typ: &'b str,
mapping: &Mapping<String>,
) -> (Type, &'b str) {
// for testing only, so we can just panic when the input is malformed // for testing only, so we can just panic when the input is malformed
let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or_else(|| typ.len()); let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or(typ.len());
match &typ[..end] { match &typ[..end] {
"list" => {
let mut s = &typ[end..];
assert_eq!(&s[0..1], "[");
let mut ty = Vec::new();
while &s[0..1] != "]" {
let result = self.internal_parse(&s[1..], mapping);
ty.push(result.0);
s = result.1;
}
assert_eq!(ty.len(), 1);
let list_elem_tvar = if let TypeEnum::TObj { params, .. } =
&*self.unifier.get_ty_immutable(self.type_mapping["list"])
{
iter_type_vars(params).next().unwrap()
} else {
unreachable!()
};
(
self.unifier
.subst(
self.type_mapping["list"],
&into_var_map([TypeVar { id: list_elem_tvar.id, ty: ty[0] }]),
)
.unwrap(),
&s[1..],
)
}
"tuple" => { "tuple" => {
let mut s = &typ[end..]; let mut s = &typ[end..];
assert_eq!(&s[0..1], "["); assert_eq!(&s[0..1], "[");
@ -148,12 +180,6 @@ impl TestEnvironment {
} }
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..]) (self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
} }
"list" => {
assert_eq!(&typ[end..end + 1], "[");
let (ty, s) = self.internal_parse(&typ[end + 1..], mapping);
assert_eq!(&s[0..1], "]");
(self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..])
}
"Record" => { "Record" => {
let mut s = &typ[end..]; let mut s = &typ[end..];
assert_eq!(&s[0..1], "["); assert_eq!(&s[0..1], "[");
@ -169,12 +195,12 @@ impl TestEnvironment {
} }
x => { x => {
let mut s = &typ[end..]; let mut s = &typ[end..];
let ty = mapping.get(x).cloned().unwrap_or_else(|| { let ty = mapping.get(x).copied().unwrap_or_else(|| {
// mapping should be type variables, type_mapping should be concrete types // mapping should be type variables, type_mapping should be concrete types
// we should not resolve the type of type variables. // we should not resolve the type of type variables.
let mut ty = *self.type_mapping.get(x).unwrap(); let mut ty = *self.type_mapping.get(x).unwrap();
let te = self.unifier.get_ty(ty); let te = self.unifier.get_ty(ty);
if let TypeEnum::TObj { params, .. } = &*te.as_ref() { if let TypeEnum::TObj { params, .. } = &*te {
if !params.is_empty() { if !params.is_empty() {
assert_eq!(&s[0..1], "["); assert_eq!(&s[0..1], "[");
let mut p = Vec::new(); let mut p = Vec::new();
@ -186,7 +212,7 @@ impl TestEnvironment {
s = &s[1..]; s = &s[1..];
ty = self ty = self
.unifier .unifier
.subst(ty, &params.keys().cloned().zip(p.into_iter()).collect()) .subst(ty, &params.keys().copied().zip(p).collect())
.unwrap_or(ty); .unwrap_or(ty);
} }
} }
@ -250,12 +276,12 @@ fn test_unify(
let mut mapping = HashMap::new(); let mut mapping = HashMap::new();
for i in 1..=variable_count { for i in 1..=variable_count {
let v = env.unifier.get_dummy_var(); let v = env.unifier.get_dummy_var();
mapping.insert(format!("v{}", i), v.0); mapping.insert(format!("v{i}"), v.ty);
} }
// unification may have side effect when we do type resolution, so freeze the types // unification may have side effect when we do type resolution, so freeze the types
// before doing unification. // before doing unification.
let mut pairs = Vec::new(); let mut pairs = Vec::new();
for (a, b) in perm.iter() { for (a, b) in &perm {
let t1 = env.parse(a, &mapping); let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping); let t2 = env.parse(b, &mapping);
pairs.push((t1, t2)); pairs.push((t1, t2));
@ -263,8 +289,8 @@ fn test_unify(
for (t1, t2) in pairs { for (t1, t2) in pairs {
env.unifier.unify(t1, t2).unwrap(); env.unifier.unify(t1, t2).unwrap();
} }
for (a, b) in verify_pairs.iter() { for (a, b) in verify_pairs {
println!("{} = {}", a, b); println!("{a} = {b}");
let t1 = env.parse(a, &mapping); let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping); let t2 = env.parse(b, &mapping);
println!("a = {}, b = {}", env.unifier.stringify(t1), env.unifier.stringify(t2)); println!("a = {}, b = {}", env.unifier.stringify(t1), env.unifier.stringify(t2));
@ -278,7 +304,7 @@ fn test_unify(
("v1", "tuple[int]"), ("v1", "tuple[int]"),
("v2", "list[int]"), ("v2", "list[int]"),
], ],
(("v1", "v2"), "Incompatible types: list[0] and tuple[0]") (("v1", "v2"), "Incompatible types: 11[0] and tuple[0]")
; "type mismatch" ; "type mismatch"
)] )]
#[test_case(2, #[test_case(2,
@ -302,7 +328,7 @@ fn test_unify(
("v1", "Record[a=float,b=int]"), ("v1", "Record[a=float,b=int]"),
("v2", "Foo[v3]"), ("v2", "Foo[v3]"),
], ],
(("v1", "v2"), "`3[typevar4]::b` field/method does not exist") (("v1", "v2"), "`3[typevar5]::b` field/method does not exist")
; "record obj merge" ; "record obj merge"
)] )]
/// Test cases for invalid unifications. /// Test cases for invalid unifications.
@ -315,12 +341,12 @@ fn test_invalid_unification(
let mut mapping = HashMap::new(); let mut mapping = HashMap::new();
for i in 1..=variable_count { for i in 1..=variable_count {
let v = env.unifier.get_dummy_var(); let v = env.unifier.get_dummy_var();
mapping.insert(format!("v{}", i), v.0); mapping.insert(format!("v{i}"), v.ty);
} }
// unification may have side effect when we do type resolution, so freeze the types // unification may have side effect when we do type resolution, so freeze the types
// before doing unification. // before doing unification.
let mut pairs = Vec::new(); let mut pairs = Vec::new();
for (a, b) in unify_pairs.iter() { for (a, b) in unify_pairs {
let t1 = env.parse(a, &mapping); let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping); let t2 = env.parse(b, &mapping);
pairs.push((t1, t2)); pairs.push((t1, t2));
@ -342,16 +368,12 @@ fn test_recursive_subst() {
with_fields(&mut env.unifier, foo_id, |_unifier, fields| { with_fields(&mut env.unifier, foo_id, |_unifier, fields| {
fields.insert("rec".into(), (foo_id, true)); fields.insert("rec".into(), (foo_id, true));
}); });
let TypeEnum::TObj { params, .. } = &*foo_ty else { let TypeEnum::TObj { params, .. } = &*foo_ty else { unreachable!() };
unreachable!()
};
let mapping = params.iter().map(|(id, _)| (*id, int)).collect(); let mapping = params.iter().map(|(id, _)| (*id, int)).collect();
let instantiated = env.unifier.subst(foo_id, &mapping).unwrap(); let instantiated = env.unifier.subst(foo_id, &mapping).unwrap();
let instantiated_ty = env.unifier.get_ty(instantiated); let instantiated_ty = env.unifier.get_ty(instantiated);
let TypeEnum::TObj { fields, .. } = &*instantiated_ty else { let TypeEnum::TObj { fields, .. } = &*instantiated_ty else { unreachable!() };
unreachable!()
};
assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int)); assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int));
assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated)); assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated));
} }
@ -363,36 +385,27 @@ fn test_virtual() {
let fun = env.unifier.add_ty(TypeEnum::TFunc(FunSignature { let fun = env.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: int, ret: int,
vars: HashMap::new(), vars: VarMap::new(),
})); }));
let bar = env.unifier.add_ty(TypeEnum::TObj { let bar = env.unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: DefinitionId(5),
fields: [("f".into(), (fun, false)), ("a".into(), (int, false))] fields: [("f".into(), (fun, false)), ("a".into(), (int, false))].into(),
.iter() params: VarMap::new(),
.cloned()
.collect::<HashMap<StrRef, _>>(),
params: HashMap::new(),
}); });
let v0 = env.unifier.get_dummy_var().0; let v0 = env.unifier.get_dummy_var().ty;
let v1 = env.unifier.get_dummy_var().0; let v1 = env.unifier.get_dummy_var().ty;
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar }); let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 }); let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
let c = env let c = env.unifier.add_record([("f".into(), RecordField::new(v1, false, None))].into());
.unifier
.add_record([("f".into(), RecordField::new(v1, false, None))].iter().cloned().collect());
env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, b).unwrap();
env.unifier.unify(b, c).unwrap(); env.unifier.unify(b, c).unwrap();
assert!(env.unifier.eq(v1, fun)); assert!(env.unifier.eq(v1, fun));
let d = env let d = env.unifier.add_record([("a".into(), RecordField::new(v1, true, None))].into());
.unifier
.add_record([("a".into(), RecordField::new(v1, true, None))].iter().cloned().collect());
assert_eq!(env.unify(b, d), Err("`virtual[5]::a` field/method does not exist".to_string())); assert_eq!(env.unify(b, d), Err("`virtual[5]::a` field/method does not exist".to_string()));
let d = env let d = env.unifier.add_record([("b".into(), RecordField::new(v1, true, None))].into());
.unifier
.add_record([("b".into(), RecordField::new(v1, true, None))].iter().cloned().collect());
assert_eq!(env.unify(b, d), Err("`virtual[5]::b` field/method does not exist".to_string())); assert_eq!(env.unify(b, d), Err("`virtual[5]::b` field/method does not exist".to_string()));
} }
@ -405,86 +418,132 @@ fn test_typevar_range() {
let int_list = env.parse("list[int]", &HashMap::new()); let int_list = env.parse("list[int]", &HashMap::new());
let float_list = env.parse("list[float]", &HashMap::new()); let float_list = env.parse("list[float]", &HashMap::new());
let list_elem_tvar = if let TypeEnum::TObj { params, .. } =
&*env.unifier.get_ty_immutable(env.type_mapping["list"])
{
iter_type_vars(params).next().unwrap()
} else {
unreachable!()
};
// unification between v and int // unification between v and int
// where v in (int, bool) // where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
env.unifier.unify(int, v).unwrap(); env.unifier.unify(int, v).unwrap();
// unification between v and list[int] // unification between v and list[int]
// where v in (int, bool) // where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
assert_eq!( assert_eq!(
env.unify(int_list, v), env.unify(int_list, v),
Err("Expected any one of these types: 0, 2, but got list[0]".to_string()) Err("Expected any one of these types: 0, 2, but got 11[0]".to_string())
); );
// unification between v and float // unification between v and float
// where v in (int, bool) // where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
assert_eq!( assert_eq!(
env.unify(float, v), env.unify(float, v),
Err("Expected any one of these types: 0, 2, but got 1".to_string()) Err("Expected any one of these types: 0, 2, but got 1".to_string())
); );
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 }); let v1_list = env.unifier.add_ty(TypeEnum::TObj {
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: v1 }]),
});
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
// unification between v and int // unification between v and int
// where v in (int, list[v1]), v1 in (int, bool) // where v in (int, list[v1]), v1 in (int, bool)
env.unifier.unify(int, v).unwrap(); env.unifier.unify(int, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
// unification between v and list[int] // unification between v and list[int]
// where v in (int, list[v1]), v1 in (int, bool) // where v in (int, list[v1]), v1 in (int, bool)
env.unifier.unify(int_list, v).unwrap(); env.unifier.unify(int_list, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
// unification between v and list[float] // unification between v and list[float]
// where v in (int, list[v1]), v1 in (int, bool) // where v in (int, list[v1]), v1 in (int, bool)
println!("float_list: {}, v: {}", env.unifier.stringify(float_list), env.unifier.stringify(v));
assert_eq!( assert_eq!(
env.unify(float_list, v), env.unify(float_list, v),
Err("Expected any one of these types: 0, list[typevar5], but got list[1]\n\nNotes:\n typevar5 ∈ {0, 2}".to_string()) Err("Expected any one of these types: 0, 11[typevar6], but got 11[1]\n\nNotes:\n typevar6 ∈ {0, 2}".to_string())
); );
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, b).unwrap();
env.unifier.unify(a, float).unwrap(); env.unifier.unify(a, float).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, b).unwrap();
assert_eq!(env.unify(a, int), Err("Expected any one of these types: 1, but got 0".into())); assert_eq!(env.unify(a, int), Err("Expected any one of these types: 1, but got 0".into()));
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let a_list = env.unifier.add_ty(TypeEnum::TObj {
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0; obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); fields: Mapping::default(),
let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).0; params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty;
let b_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: b }]),
});
let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).ty;
env.unifier.unify(a_list, b_list).unwrap(); env.unifier.unify(a_list, b_list).unwrap();
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float }); let float_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: float }]),
});
env.unifier.unify(a_list, float_list).unwrap(); env.unifier.unify(a_list, float_list).unwrap();
// previous unifications should not affect a and b // previous unifications should not affect a and b
env.unifier.unify(a, int).unwrap(); env.unifier.unify(a, int).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let a_list = env.unifier.add_ty(TypeEnum::TObj {
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let b_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: b }]),
});
env.unifier.unify(a_list, b_list).unwrap(); env.unifier.unify(a_list, b_list).unwrap();
let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int }); let int_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: int }]),
});
assert_eq!( assert_eq!(
env.unify(a_list, int_list), env.unify(a_list, int_list),
Err("Incompatible types: list[typevar22] and list[0]\ Err("Incompatible types: 11[typevar23] and 11[0]\
\n\nNotes:\n typevar22 {1}".into()) \n\nNotes:\n typevar23 {1}"
.into())
); );
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_dummy_var().0; let b = env.unifier.get_dummy_var().ty;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let a_list = env.unifier.add_ty(TypeEnum::TObj {
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0; obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty;
let b_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: b }]),
});
env.unifier.unify(a_list, b_list).unwrap(); env.unifier.unify(a_list, b_list).unwrap();
assert_eq!( assert_eq!(
env.unify(b, boolean), env.unify(b, boolean),
@ -495,17 +554,29 @@ fn test_typevar_range() {
#[test] #[test]
fn test_rigid_var() { fn test_rigid_var() {
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let a = env.unifier.get_fresh_rigid_var(None, None).0; let a = env.unifier.get_fresh_rigid_var(None, None).ty;
let b = env.unifier.get_fresh_rigid_var(None, None).0; let b = env.unifier.get_fresh_rigid_var(None, None).ty;
let x = env.unifier.get_dummy_var().0; let x = env.unifier.get_dummy_var().ty;
let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a }); let list_elem_tvar = env.unifier.get_fresh_var(Some("list_elem".into()), None);
let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x }); let list_a = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let list_x = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: x }]),
});
let int = env.parse("int", &HashMap::new()); let int = env.parse("int", &HashMap::new());
let list_int = env.parse("list[int]", &HashMap::new()); let list_int = env.parse("list[int]", &HashMap::new());
assert_eq!(env.unify(a, b), Err("Incompatible types: typevar3 and typevar2".to_string())); assert_eq!(env.unify(a, b), Err("Incompatible types: typevar4 and typevar3".to_string()));
env.unifier.unify(list_a, list_x).unwrap(); env.unifier.unify(list_a, list_x).unwrap();
assert_eq!(env.unify(list_x, list_int), Err("Incompatible types: list[typevar2] and list[0]".to_string())); assert_eq!(
env.unify(list_x, list_int),
Err("Incompatible types: 11[typevar3] and 11[0]".to_string())
);
env.unifier.replace_rigid_var(a, int); env.unifier.replace_rigid_var(a, int);
env.unifier.unify(list_x, list_int).unwrap(); env.unifier.unify(list_x, list_int).unwrap();
@ -519,16 +590,26 @@ fn test_instantiation() {
let float = env.parse("float", &HashMap::new()); let float = env.parse("float", &HashMap::new());
let list_int = env.parse("list[int]", &HashMap::new()); let list_int = env.parse("list[int]", &HashMap::new());
let obj_map: HashMap<_, _> = let list_elem_tvar = if let TypeEnum::TObj { params, .. } =
[(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect(); &*env.unifier.get_ty_immutable(env.type_mapping["list"])
{
iter_type_vars(params).next().unwrap()
} else {
unreachable!()
};
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let obj_map: HashMap<_, _> = [(0usize, "int"), (1, "float"), (2, "bool"), (11, "list")].into();
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).0; let list_v = env
let t = env.unifier.get_dummy_var().0; .unifier
.subst(env.type_mapping["list"], &into_var_map([TypeVar { id: list_elem_tvar.id, ty: v }]))
.unwrap();
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty;
let t = env.unifier.get_dummy_var().ty;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] }); let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).0; let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty;
// t = TypeVar('t') // t = TypeVar('t')
// v = TypeVar('v', int, bool) // v = TypeVar('v', int, bool)
// v1 = TypeVar('v1', 'list[v]', int) // v1 = TypeVar('v1', 'list[v]', int)
@ -550,7 +631,7 @@ fn test_instantiation() {
tuple[int, list[bool], list[int]] tuple[int, list[bool], list[int]]
tuple[int, list[int], float] tuple[int, list[int], float]
tuple[int, list[int], list[int]] tuple[int, list[int], list[int]]
v5" v6"
} }
.split('\n') .split('\n')
.collect_vec(); .collect_vec();
@ -559,8 +640,8 @@ fn test_instantiation() {
.map(|ty| { .map(|ty| {
env.unifier.internal_stringify( env.unifier.internal_stringify(
*ty, *ty,
&mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| (*obj_map.get(&i).unwrap()).to_string(),
&mut |i| format!("v{}", i), &mut |i| format!("v{i}"),
&mut None, &mut None,
) )
}) })

View File

@ -16,21 +16,10 @@ pub struct UnificationTable<V> {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
enum Action<V> { enum Action<V> {
Parent { Parent { key: usize, original_parent: usize },
key: usize, Value { key: usize, original_value: Option<V> },
original_parent: usize, Rank { key: usize, original_rank: u32 },
}, Marker { generation: u32 },
Value {
key: usize,
original_value: Option<V>,
},
Rank {
key: usize,
original_rank: u32,
},
Marker {
generation: u32,
}
} }
impl<V> Default for UnificationTable<V> { impl<V> Default for UnificationTable<V> {
@ -41,7 +30,13 @@ impl<V> Default for UnificationTable<V> {
impl<V> UnificationTable<V> { impl<V> UnificationTable<V> {
pub fn new() -> UnificationTable<V> { pub fn new() -> UnificationTable<V> {
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 } UnificationTable {
parents: Vec::new(),
ranks: Vec::new(),
values: Vec::new(),
log: Vec::new(),
generation: 0,
}
} }
pub fn new_key(&mut self, v: V) -> UnificationKey { pub fn new_key(&mut self, v: V) -> UnificationKey {
@ -125,7 +120,10 @@ impl<V> UnificationTable<V> {
pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) { pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot; let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot restoration error"); assert!(self.log.len() >= log_len, "snapshot restoration error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot restoration error"); assert!(
matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation),
"snapshot restoration error"
);
for action in self.log.drain(log_len - 1..).rev() { for action in self.log.drain(log_len - 1..).rev() {
match action { match action {
Action::Parent { key, original_parent } => { Action::Parent { key, original_parent } => {
@ -145,7 +143,10 @@ impl<V> UnificationTable<V> {
pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) { pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot; let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot discard error"); assert!(self.log.len() >= log_len, "snapshot discard error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot discard error"); assert!(
matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation),
"snapshot discard error"
);
self.log.clear(); self.log.clear();
} }
} }
@ -159,11 +160,23 @@ where
.enumerate() .enumerate()
.map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None }) .map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None })
.collect(); .collect();
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values, log: Vec::new(), generation: 0 } UnificationTable {
parents: self.parents.clone(),
ranks: self.ranks.clone(),
values,
log: Vec::new(),
generation: 0,
}
} }
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> { pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect(); let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect();
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values, log: Vec::new(), generation: 0 } UnificationTable {
parents: table.parents.clone(),
ranks: table.ranks.clone(),
values,
log: Vec::new(),
generation: 0,
}
} }
} }

View File

@ -32,7 +32,6 @@ pub struct DwarfReader<'a> {
} }
impl<'a> DwarfReader<'a> { impl<'a> DwarfReader<'a> {
pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader { pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader {
DwarfReader { slice, virt_addr, base_slice: slice, base_virt_addr: virt_addr } DwarfReader { slice, virt_addr, base_slice: slice, base_virt_addr: virt_addr }
} }
@ -60,7 +59,7 @@ impl<'a> DwarfReader<'a> {
let mut byte: u8; let mut byte: u8;
loop { loop {
byte = self.read_u8(); byte = self.read_u8();
result |= ((byte & 0x7F) as u64) << shift; result |= u64::from(byte & 0x7F) << shift;
shift += 7; shift += 7;
if byte & 0x80 == 0 { if byte & 0x80 == 0 {
break; break;
@ -75,7 +74,7 @@ impl<'a> DwarfReader<'a> {
let mut byte: u8; let mut byte: u8;
loop { loop {
byte = self.read_u8(); byte = self.read_u8();
result |= ((byte & 0x7F) as u64) << shift; result |= u64::from(byte & 0x7F) << shift;
shift += 7; shift += 7;
if byte & 0x80 == 0 { if byte & 0x80 == 0 {
break; break;
@ -157,10 +156,9 @@ fn read_encoded_pointer(reader: &mut DwarfReader, encoding: u8) -> Result<usize,
} }
match encoding & 0x0F { match encoding & 0x0F {
DW_EH_PE_absptr => Ok(reader.read_u32() as usize), DW_EH_PE_absptr | DW_EH_PE_udata4 => Ok(reader.read_u32() as usize),
DW_EH_PE_uleb128 => Ok(reader.read_uleb128() as usize), DW_EH_PE_uleb128 => Ok(reader.read_uleb128() as usize),
DW_EH_PE_udata2 => Ok(reader.read_u16() as usize), DW_EH_PE_udata2 => Ok(reader.read_u16() as usize),
DW_EH_PE_udata4 => Ok(reader.read_u32() as usize),
DW_EH_PE_udata8 => Ok(reader.read_u64() as usize), DW_EH_PE_udata8 => Ok(reader.read_u64() as usize),
DW_EH_PE_sleb128 => Ok(reader.read_sleb128() as usize), DW_EH_PE_sleb128 => Ok(reader.read_sleb128() as usize),
DW_EH_PE_sdata2 => Ok(reader.read_i16() as usize), DW_EH_PE_sdata2 => Ok(reader.read_i16() as usize),
@ -170,10 +168,7 @@ fn read_encoded_pointer(reader: &mut DwarfReader, encoding: u8) -> Result<usize,
} }
} }
fn read_encoded_pointer_with_pc( fn read_encoded_pointer_with_pc(reader: &mut DwarfReader, encoding: u8) -> Result<usize, ()> {
reader: &mut DwarfReader,
encoding: u8,
) -> Result<usize, ()> {
let entry_virt_addr = reader.virt_addr; let entry_virt_addr = reader.virt_addr;
let mut result = read_encoded_pointer(reader, encoding)?; let mut result = read_encoded_pointer(reader, encoding)?;
@ -223,11 +218,10 @@ pub struct EH_Frame<'a> {
} }
impl<'a> EH_Frame<'a> { impl<'a> EH_Frame<'a> {
/// Creates an [EH_Frame] using the bytes in the `.eh_frame` section and its address in the ELF /// Creates an [EH_Frame] using the bytes in the `.eh_frame` section and its address in the ELF
/// file. /// file.
pub fn new(eh_frame_slice: &[u8], eh_frame_addr: u32) -> Result<EH_Frame, ()> { pub fn new(eh_frame_slice: &[u8], eh_frame_addr: u32) -> EH_Frame {
Ok(EH_Frame { reader: DwarfReader::new(eh_frame_slice, eh_frame_addr) }) EH_Frame { reader: DwarfReader::new(eh_frame_slice, eh_frame_addr) }
} }
/// Returns an [Iterator] over all Call Frame Information (CFI) records. /// Returns an [Iterator] over all Call Frame Information (CFI) records.
@ -235,10 +229,7 @@ impl<'a> EH_Frame<'a> {
let reader = DwarfReader::from_reader(&self.reader, true); let reader = DwarfReader::from_reader(&self.reader, true);
let len = reader.slice.len(); let len = reader.slice.len();
CFI_Records { CFI_Records { reader, available: len }
reader,
available: len,
}
} }
} }
@ -255,7 +246,6 @@ pub struct CFI_Record<'a> {
} }
impl<'a> CFI_Record<'a> { impl<'a> CFI_Record<'a> {
pub fn from_reader(cie_reader: &mut DwarfReader<'a>) -> Result<CFI_Record<'a>, ()> { pub fn from_reader(cie_reader: &mut DwarfReader<'a>) -> Result<CFI_Record<'a>, ()> {
let length = cie_reader.read_u32(); let length = cie_reader.read_u32();
let fde_reader = match length { let fde_reader = match length {
@ -264,7 +254,7 @@ impl<'a> CFI_Record<'a> {
// length == u32::MAX means that the length is only representable with 64 bits, // length == u32::MAX means that the length is only representable with 64 bits,
// which does not make sense in a system with 32-bit address. // which does not make sense in a system with 32-bit address.
0xFFFFFFFF => unimplemented!(), 0xFFFF_FFFF => unimplemented!(),
_ => { _ => {
let mut fde_reader = DwarfReader::from_reader(cie_reader, false); let mut fde_reader = DwarfReader::from_reader(cie_reader, false);
@ -323,10 +313,7 @@ impl<'a> CFI_Record<'a> {
} }
assert_ne!(fde_pointer_encoding, DW_EH_PE_omit); assert_ne!(fde_pointer_encoding, DW_EH_PE_omit);
Ok(CFI_Record { Ok(CFI_Record { fde_pointer_encoding, fde_reader })
fde_pointer_encoding,
fde_reader,
})
} }
/// Returns a [DwarfReader] initialized to the first Frame Description Entry (FDE) of this CFI /// Returns a [DwarfReader] initialized to the first Frame Description Entry (FDE) of this CFI
@ -340,11 +327,7 @@ impl<'a> CFI_Record<'a> {
let reader = self.get_fde_reader(); let reader = self.get_fde_reader();
let len = reader.slice.len(); let len = reader.slice.len();
FDE_Records { FDE_Records { pointer_encoding: self.fde_pointer_encoding, reader, available: len }
pointer_encoding: self.fde_pointer_encoding,
reader,
available: len,
}
} }
} }
@ -371,7 +354,7 @@ impl<'a> Iterator for CFI_Records<'a> {
let length = match length { let length = match length {
// eh_frame with 0-length means the CIE is terminated // eh_frame with 0-length means the CIE is terminated
0 => return None, 0 => return None,
0xFFFFFFFF => unimplemented!("CIE entries larger than 4 bytes not supported"), 0xFFFF_FFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
other => other, other => other,
} as usize; } as usize;
@ -387,7 +370,7 @@ impl<'a> Iterator for CFI_Records<'a> {
// Skip this record if it is a FDE // Skip this record if it is a FDE
if cie_ptr == 0 { if cie_ptr == 0 {
// Rewind back to the start of the CFI Record // Rewind back to the start of the CFI Record
return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap()) return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap());
} }
} }
} }
@ -417,7 +400,7 @@ impl<'a> Iterator for FDE_Records<'a> {
let length = match self.reader.read_u32() { let length = match self.reader.read_u32() {
// eh_frame with 0-length means the CIE is terminated // eh_frame with 0-length means the CIE is terminated
0 => return None, 0 => return None,
0xFFFFFFFF => unimplemented!("CIE entries larger than 4 bytes not supported"), 0xFFFF_FFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
other => other, other => other,
} as usize; } as usize;
@ -448,7 +431,6 @@ pub struct EH_Frame_Hdr<'a> {
} }
impl<'a> EH_Frame_Hdr<'a> { impl<'a> EH_Frame_Hdr<'a> {
/// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory. /// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory.
/// ///
/// Load address is not known at this point. /// Load address is not known at this point.
@ -464,8 +446,9 @@ impl<'a> EH_Frame_Hdr<'a> {
writer.write_u8(0x03); // fde_count_enc - 4-byte unsigned value writer.write_u8(0x03); // fde_count_enc - 4-byte unsigned value
writer.write_u8(0x3B); // table_enc - .eh_frame_hdr section-relative 4-byte signed value writer.write_u8(0x3B); // table_enc - .eh_frame_hdr section-relative 4-byte signed value
let eh_frame_offset = eh_frame_addr let eh_frame_offset = eh_frame_addr.wrapping_sub(
.wrapping_sub(eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::<u8>() as u32) * 4)); eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::<u8>() as u32) * 4),
);
writer.write_u32(eh_frame_offset); // eh_frame_ptr writer.write_u32(eh_frame_offset); // eh_frame_ptr
writer.write_u32(0); // `fde_count`, will be written in finalize_fde writer.write_u32(0); // `fde_count`, will be written in finalize_fde
@ -492,7 +475,10 @@ impl<'a> EH_Frame_Hdr<'a> {
self.fde_writer.write_u32(*init_loc); self.fde_writer.write_u32(*init_loc);
self.fde_writer.write_u32(*addr); self.fde_writer.write_u32(*addr);
} }
LittleEndian::write_u32(&mut self.fde_writer.slice[Self::fde_count_offset()..], self.fdes.len() as u32); LittleEndian::write_u32(
&mut self.fde_writer.slice[Self::fde_count_offset()..],
self.fdes.len() as u32,
);
} }
pub fn size_from_eh_frame(eh_frame: &[u8]) -> usize { pub fn size_from_eh_frame(eh_frame: &[u8]) -> usize {
@ -504,7 +490,7 @@ impl<'a> EH_Frame_Hdr<'a> {
// The original length field should be able to hold the entire value. // The original length field should be able to hold the entire value.
// The device memory space is limited to 32-bits addresses anyway. // The device memory space is limited to 32-bits addresses anyway.
let entry_length = reader.read_u32(); let entry_length = reader.read_u32();
if entry_length == 0 || entry_length == 0xFFFFFFFF { if entry_length == 0 || entry_length == 0xFFFF_FFFF {
unimplemented!() unimplemented!()
} }
@ -515,7 +501,7 @@ impl<'a> EH_Frame_Hdr<'a> {
fde_count += 1; fde_count += 1;
} }
reader.offset(entry_length - mem::size_of::<u32>() as u32) reader.offset(entry_length - mem::size_of::<u32>() as u32);
} }
12 + fde_count * 8 12 + fde_count * 8

View File

@ -1,5 +1,5 @@
/* generated from elf.h with rust-bindgen and then manually altered */ /* generated from elf.h with rust-bindgen and then manually altered */
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code)] #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, clippy::pedantic)]
pub const EI_NIDENT: usize = 16; pub const EI_NIDENT: usize = 16;
pub const EI_MAG0: usize = 0; pub const EI_MAG0: usize = 0;

View File

@ -1,3 +1,26 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::doc_markdown,
clippy::enum_glob_use,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::similar_names,
clippy::struct_field_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
use dwarf::*; use dwarf::*;
use elf::*; use elf::*;
use std::collections::HashMap; use std::collections::HashMap;
@ -70,45 +93,45 @@ struct SectionRecord<'a> {
data: Vec<u8>, data: Vec<u8>,
} }
fn read_unaligned<T: Copy>(data: &[u8], offset: usize) -> Result<T, ()> { fn read_unaligned<T: Copy>(data: &[u8], offset: usize) -> Option<T> {
if data.len() < offset + mem::size_of::<T>() { if data.len() < offset + mem::size_of::<T>() {
Err(()) None
} else { } else {
let ptr = data.as_ptr().wrapping_add(offset) as *const T; let ptr = data.as_ptr().wrapping_add(offset).cast();
Ok(unsafe { ptr::read_unaligned(ptr) }) Some(unsafe { ptr::read_unaligned(ptr) })
} }
} }
pub fn get_ref_slice<T: Copy>(data: &[u8], offset: usize, len: usize) -> Result<&[T], ()> { #[must_use]
pub fn get_ref_slice<T: Copy>(data: &[u8], offset: usize, len: usize) -> Option<&[T]> {
if data.len() < offset + mem::size_of::<T>() * len { if data.len() < offset + mem::size_of::<T>() * len {
Err(()) None
} else { } else {
let ptr = data.as_ptr().wrapping_add(offset) as *const T; let ptr = data.as_ptr().wrapping_add(offset).cast();
Ok(unsafe { slice::from_raw_parts(ptr, len) }) Some(unsafe { slice::from_raw_parts(ptr, len) })
} }
} }
fn from_struct_vec<T>(struct_vec: Vec<T>) -> Vec<u8> { fn from_struct_slice<T>(struct_vec: &[T]) -> Vec<u8> {
let ptr = struct_vec.as_ptr(); let ptr = struct_vec.as_ptr();
unsafe { slice::from_raw_parts(ptr as *const u8, struct_vec.len() * mem::size_of::<T>()) } unsafe { slice::from_raw_parts(ptr.cast(), mem::size_of_val(struct_vec)) }.to_vec()
.to_vec()
} }
fn to_struct_slice<T>(bytes: &[u8]) -> &[T] { fn to_struct_slice<T>(bytes: &[u8]) -> &[T] {
unsafe { slice::from_raw_parts(bytes.as_ptr() as *const T, bytes.len() / mem::size_of::<T>()) } unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len() / mem::size_of::<T>()) }
} }
fn to_struct_mut_slice<T>(bytes: &mut [u8]) -> &mut [T] { fn to_struct_mut_slice<T>(bytes: &mut [u8]) -> &mut [T] {
unsafe { unsafe {
slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut T, bytes.len() / mem::size_of::<T>()) slice::from_raw_parts_mut(bytes.as_mut_ptr().cast(), bytes.len() / mem::size_of::<T>())
} }
} }
fn elf_hash(name: &[u8]) -> u32 { fn elf_hash(name: &[u8]) -> u32 {
let mut h: u32 = 0; let mut h: u32 = 0;
for c in name { for c in name {
h = (h << 4) + *c as u32; h = (h << 4) + u32::from(*c);
let g = h & 0xf0000000; let g = h & 0xf000_0000;
if g != 0 { if g != 0 {
h ^= g >> 24; h ^= g >> 24;
h &= !g; h &= !g;
@ -202,22 +225,26 @@ impl<'a> Linker<'a> {
relocs: &[R], relocs: &[R],
target_section: Elf32_Word, target_section: Elf32_Word,
) -> Result<(), Error> { ) -> Result<(), Error> {
type RelocateFn = dyn Fn(&mut [u8], Elf32_Word);
struct RelocInfo<'a, R> {
pub defined_val: bool,
pub indirect_reloc: Option<&'a R>,
pub pc_relative: bool,
pub relocate: Option<Box<RelocateFn>>,
}
for reloc in relocs { for reloc in relocs {
let sym = match reloc.sym_info() as usize { let sym = match reloc.sym_info() as usize {
STN_UNDEF => None, STN_UNDEF => None,
sym_index => Some( sym_index => {
self.symtab Some(self.symtab.get(sym_index).ok_or("symbol out of bounds of symbol table")?)
.get(sym_index) }
.ok_or("symbol out of bounds of symbol table")?,
),
}; };
let resolve_symbol_addr = let resolve_symbol_addr =
|sym_option: Option<&Elf32_Sym>| -> Result<Elf32_Word, Error> { |sym_option: Option<&Elf32_Sym>| -> Result<Elf32_Word, Error> {
let sym = match sym_option { let Some(sym) = sym_option else { return Ok(0) };
Some(sym) => sym,
None => return Ok(0),
};
match sym.st_shndx { match sym.st_shndx {
SHN_UNDEF => Err(Error::Lookup("undefined symbol")), SHN_UNDEF => Err(Error::Lookup("undefined symbol")),
@ -244,13 +271,6 @@ impl<'a> Linker<'a> {
.ok_or(Error::Parsing("Cannot find section with matching sh_index")) .ok_or(Error::Parsing("Cannot find section with matching sh_index"))
}; };
struct RelocInfo<'a, R> {
pub defined_val: bool,
pub indirect_reloc: Option<&'a R>,
pub pc_relative: bool,
pub relocate: Option<Box<dyn Fn(&mut [u8], Elf32_Word)>>,
}
let classify = |reloc: &R, sym_option: Option<&Elf32_Sym>| -> Option<RelocInfo<R>> { let classify = |reloc: &R, sym_option: Option<&Elf32_Sym>| -> Option<RelocInfo<R>> {
let defined_val = sym_option.map_or(true, |sym| { let defined_val = sym_option.map_or(true, |sym| {
sym.st_shndx != SHN_UNDEF || ELF32_ST_BIND(sym.st_info) == STB_LOCAL sym.st_shndx != SHN_UNDEF || ELF32_ST_BIND(sym.st_info) == STB_LOCAL
@ -262,7 +282,7 @@ impl<'a> Linker<'a> {
indirect_reloc: None, indirect_reloc: None,
pc_relative: true, pc_relative: true,
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32(target_word, value) LittleEndian::write_u32(target_word, value);
})), })),
}), }),
@ -273,9 +293,9 @@ impl<'a> Linker<'a> {
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32( LittleEndian::write_u32(
target_word, target_word,
(LittleEndian::read_u32(target_word) & 0x80000000) (LittleEndian::read_u32(target_word) & 0x8000_0000)
| value & 0x7FFFFFFF, | value & 0x7FFF_FFFF,
) );
})), })),
}), }),
@ -297,8 +317,8 @@ impl<'a> Linker<'a> {
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
let auipc_raw = LittleEndian::read_u32(target_word); let auipc_raw = LittleEndian::read_u32(target_word);
let auipc_insn = let auipc_insn =
(auipc_raw & 0xFFF) | ((value + 0x800) & 0xFFFFF000); (auipc_raw & 0xFFF) | ((value + 0x800) & 0xFFFF_F000);
LittleEndian::write_u32(target_word, auipc_insn) LittleEndian::write_u32(target_word, auipc_insn);
})), })),
}) })
} }
@ -308,15 +328,14 @@ impl<'a> Linker<'a> {
indirect_reloc: None, indirect_reloc: None,
pc_relative: true, pc_relative: true,
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32(target_word, value) LittleEndian::write_u32(target_word, value);
})), })),
}), }),
R_RISCV_PCREL_LO12_I => { R_RISCV_PCREL_LO12_I => {
let expected_offset = sym_option.map_or(0, |sym| sym.st_value); let expected_offset = sym_option.map_or(0, |sym| sym.st_value);
let indirect_reloc = relocs let indirect_reloc =
.iter() relocs.iter().find(|reloc| reloc.offset() == expected_offset)?;
.find(|reloc| reloc.offset() == expected_offset)?;
Some(RelocInfo { Some(RelocInfo {
defined_val: { defined_val: {
let indirect_sym = let indirect_sym =
@ -330,14 +349,14 @@ impl<'a> Linker<'a> {
// Here, we convert to direct addressing // Here, we convert to direct addressing
// GOT reloc (indirect) -> lw + addi // GOT reloc (indirect) -> lw + addi
// PCREL reloc (direct) -> addi // PCREL reloc (direct) -> addi
let (lo_opcode, lo_funct3) = (0b0010011, 0b000); let (lo_opcode, lo_funct3) = (0b001_0011, 0b000);
let addi_lw_raw = LittleEndian::read_u32(target_word); let addi_lw_raw = LittleEndian::read_u32(target_word);
let addi_insn = lo_opcode let addi_insn = lo_opcode
| (addi_lw_raw & 0xF8F80) | (addi_lw_raw & 0xF8F80)
| (lo_funct3 << 12) | (lo_funct3 << 12)
| ((value & 0xFFF) << 20); | ((value & 0xFFF) << 20);
LittleEndian::write_u32(target_word, addi_insn) LittleEndian::write_u32(target_word, addi_insn);
})), })),
}) })
} }
@ -354,10 +373,7 @@ impl<'a> Linker<'a> {
indirect_reloc: None, indirect_reloc: None,
pc_relative: false, pc_relative: false,
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32( LittleEndian::write_u32(target_word, value);
target_word,
value,
)
})), })),
}), }),
@ -367,7 +383,7 @@ impl<'a> Linker<'a> {
pc_relative: false, pc_relative: false,
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
let old_value = LittleEndian::read_u32(target_word); let old_value = LittleEndian::read_u32(target_word);
LittleEndian::write_u32(target_word, old_value.wrapping_add(value)) LittleEndian::write_u32(target_word, old_value.wrapping_add(value));
})), })),
}), }),
@ -377,7 +393,7 @@ impl<'a> Linker<'a> {
pc_relative: false, pc_relative: false,
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
let old_value = LittleEndian::read_u32(target_word); let old_value = LittleEndian::read_u32(target_word);
LittleEndian::write_u32(target_word, old_value.wrapping_sub(value)) LittleEndian::write_u32(target_word, old_value.wrapping_sub(value));
})), })),
}), }),
@ -386,10 +402,7 @@ impl<'a> Linker<'a> {
indirect_reloc: None, indirect_reloc: None,
pc_relative: false, pc_relative: false,
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u16( LittleEndian::write_u16(target_word, value as u16);
target_word,
value as u16,
)
})), })),
}), }),
@ -402,7 +415,7 @@ impl<'a> Linker<'a> {
LittleEndian::write_u16( LittleEndian::write_u16(
target_word, target_word,
old_value.wrapping_add(value as u16), old_value.wrapping_add(value as u16),
) );
})), })),
}), }),
@ -415,7 +428,7 @@ impl<'a> Linker<'a> {
LittleEndian::write_u16( LittleEndian::write_u16(
target_word, target_word,
old_value.wrapping_sub(value as u16), old_value.wrapping_sub(value as u16),
) );
})), })),
}), }),
@ -497,7 +510,7 @@ impl<'a> Linker<'a> {
if let Some(relocate) = reloc_info.relocate { if let Some(relocate) = reloc_info.relocate {
let target_word = &mut target_sec_image[reloc.offset() as usize..]; let target_word = &mut target_sec_image[reloc.offset() as usize..];
relocate(target_word, value) relocate(target_word, value);
} else { } else {
self.rela_dyn_relas.push(Elf32_Rela { self.rela_dyn_relas.push(Elf32_Rela {
r_offset: rela_off, r_offset: rela_off,
@ -545,16 +558,18 @@ impl<'a> Linker<'a> {
let eh_frame_slice = eh_frame_rec.data.as_slice(); let eh_frame_slice = eh_frame_rec.data.as_slice();
// Prepare a new buffer to dodge borrow check // Prepare a new buffer to dodge borrow check
let mut eh_frame_hdr_vec: Vec<u8> = vec![0; eh_frame_hdr_rec.shdr.sh_size as usize]; let mut eh_frame_hdr_vec: Vec<u8> = vec![0; eh_frame_hdr_rec.shdr.sh_size as usize];
let eh_frame = EH_Frame::new(eh_frame_slice, eh_frame_rec.shdr.sh_offset) let eh_frame = EH_Frame::new(eh_frame_slice, eh_frame_rec.shdr.sh_offset);
.map_err(|()| "cannot read EH frame")?;
let mut eh_frame_hdr = EH_Frame_Hdr::new( let mut eh_frame_hdr = EH_Frame_Hdr::new(
eh_frame_hdr_vec.as_mut_slice(), eh_frame_hdr_vec.as_mut_slice(),
eh_frame_hdr_rec.shdr.sh_offset, eh_frame_hdr_rec.shdr.sh_offset,
eh_frame_rec.shdr.sh_offset, eh_frame_rec.shdr.sh_offset,
); );
eh_frame.cfi_records() eh_frame.cfi_records().flat_map(|cfi| cfi.fde_records()).for_each(&mut |(
.flat_map(|cfi| cfi.fde_records()) init_pos,
.for_each(&mut |(init_pos, virt_addr)| eh_frame_hdr.add_fde(init_pos, virt_addr)); virt_addr,
)| {
eh_frame_hdr.add_fde(init_pos, virt_addr);
});
// Sort FDE entries in .eh_frame_hdr // Sort FDE entries in .eh_frame_hdr
eh_frame_hdr.finalize_fde(); eh_frame_hdr.finalize_fde();
@ -568,39 +583,114 @@ impl<'a> Linker<'a> {
} }
pub fn ld(data: &'a [u8]) -> Result<Vec<u8>, Error> { pub fn ld(data: &'a [u8]) -> Result<Vec<u8>, Error> {
let ehdr = read_unaligned::<Elf32_Ehdr>(data, 0).map_err(|()| "cannot read ELF header")?; fn allocate_rela_dyn<R: Relocatable>(
linker: &Linker,
relocs: &[R],
) -> Result<(usize, Vec<u32>), Error> {
let mut alloc_size = 0;
let mut rela_dyn_sym_indices = Vec::new();
for reloc in relocs {
if reloc.sym_info() as usize == STN_UNDEF {
continue;
}
let sym: &Elf32_Sym = linker
.symtab
.get(reloc.sym_info() as usize)
.ok_or("symbol out of bounds of symbol table")?;
match (linker.isa, reloc.type_info()) {
// Absolute address relocations
// A runtime relocation is needed to find the loading address
(Isa::CortexA9, R_ARM_ABS32) | (Isa::RiscV32, R_RISCV_32) => {
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
if ELF32_ST_BIND(sym.st_info) == STB_GLOBAL && sym.st_shndx == SHN_UNDEF {
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
// Relative address relocations
// Relay the relocation to the runtime linker only if the symbol is not defined
(Isa::CortexA9, R_ARM_REL32 | R_ARM_PREL31 | R_ARM_TARGET2)
| (
Isa::RiscV32,
R_RISCV_CALL_PLT | R_RISCV_PCREL_HI20 | R_RISCV_GOT_HI20 | R_RISCV_32_PCREL
| R_RISCV_SET32 | R_RISCV_ADD32 | R_RISCV_SUB32 | R_RISCV_SET16
| R_RISCV_ADD16 | R_RISCV_SUB16 | R_RISCV_SET8 | R_RISCV_ADD8
| R_RISCV_SUB8 | R_RISCV_SET6 | R_RISCV_SUB6,
) => {
if ELF32_ST_BIND(sym.st_info) == STB_GLOBAL && sym.st_shndx == SHN_UNDEF {
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
// RISC-V: Lower 12-bits relocations
// If the upper 20-bits relocation cannot be resolved,
// this relocation will be relayed to the runtime linker.
(Isa::RiscV32, R_RISCV_PCREL_LO12_I) => {
// Find the HI20 relocation
let indirect_reloc = relocs
.iter()
.find(|reloc| reloc.offset() == sym.st_value)
.ok_or("malformatted LO12 relocation")?;
let indirect_sym = linker.symtab[indirect_reloc.sym_info() as usize];
if ELF32_ST_BIND(indirect_sym.st_info) == STB_GLOBAL
&& indirect_sym.st_shndx == SHN_UNDEF
{
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
_ => {
println!("Relocation type 0x{:X?} is not supported", reloc.type_info());
unimplemented!()
}
}
}
Ok((alloc_size, rela_dyn_sym_indices))
}
let Some(ehdr) = read_unaligned::<Elf32_Ehdr>(data, 0) else {
Err("cannot read ELF header")?
};
let isa = match ehdr.e_machine { let isa = match ehdr.e_machine {
EM_ARM => Isa::CortexA9, EM_ARM => Isa::CortexA9,
EM_RISCV => Isa::RiscV32, EM_RISCV => Isa::RiscV32,
_ => return Err(Error::Parsing("unsupported architecture")), _ => return Err(Error::Parsing("unsupported architecture")),
}; };
let shdrs = get_ref_slice::<Elf32_Shdr>(data, ehdr.e_shoff as usize, ehdr.e_shnum as usize) let Some(shdrs) =
.map_err(|()| "cannot read section header table")?; get_ref_slice::<Elf32_Shdr>(data, ehdr.e_shoff as usize, ehdr.e_shnum as usize)
else {
Err("cannot read section header table")?
};
// Read .strtab // Read .strtab
let strtab_shdr = shdrs[ehdr.e_shstrndx as usize]; let strtab_shdr = shdrs[ehdr.e_shstrndx as usize];
let strtab = let Some(strtab) =
get_ref_slice::<u8>(data, strtab_shdr.sh_offset as usize, strtab_shdr.sh_size as usize) get_ref_slice::<u8>(data, strtab_shdr.sh_offset as usize, strtab_shdr.sh_size as usize)
.map_err(|()| "cannot read the string table from data")?; else {
Err("cannot read the string table from data")?
};
// Read .symtab // Read .symtab
let symtab_shdr = shdrs let symtab_shdr = shdrs
.iter() .iter()
.find(|shdr| shdr.sh_type as usize == SHT_SYMTAB) .find(|shdr| shdr.sh_type as usize == SHT_SYMTAB)
.ok_or(Error::Parsing("cannot find the symbol table"))?; .ok_or(Error::Parsing("cannot find the symbol table"))?;
let symtab = get_ref_slice::<Elf32_Sym>( let Some(symtab) = get_ref_slice::<Elf32_Sym>(
data, data,
symtab_shdr.sh_offset as usize, symtab_shdr.sh_offset as usize,
symtab_shdr.sh_size as usize / mem::size_of::<Elf32_Sym>(), symtab_shdr.sh_size as usize / mem::size_of::<Elf32_Sym>(),
) ) else {
.map_err(|()| "cannot read the symbol table from data")?; Err("cannot read the symbol table from data")?
};
// Section table for the .elf paired with the section name // Section table for the .elf paired with the section name
// To be formalized incrementally // To be formalized incrementally
// Very hashmap-like structure, but the order matters, so it is a vector // Very hashmap-like structure, but the order matters, so it is a vector
let elf_shdrs = vec![ let elf_shdrs = vec![SectionRecord {
SectionRecord {
shdr: Elf32_Shdr { shdr: Elf32_Shdr {
sh_name: 0, sh_name: 0,
sh_type: 0, sh_type: 0,
@ -615,8 +705,7 @@ impl<'a> Linker<'a> {
}, },
name: "", name: "",
data: vec![0; 0], data: vec![0; 0],
}, }];
];
let elf_sh_data_off = mem::size_of::<Elf32_Ehdr>() + mem::size_of::<Elf32_Phdr>() * 5; let elf_sh_data_off = mem::size_of::<Elf32_Ehdr>() + mem::size_of::<Elf32_Phdr>() * 5;
// Image of the linked dynamic library, to be formalized incrementally // Image of the linked dynamic library, to be formalized incrementally
@ -752,21 +841,27 @@ impl<'a> Linker<'a> {
($shdr: expr, $stmt: expr) => { ($shdr: expr, $stmt: expr) => {
match $shdr.sh_type as usize { match $shdr.sh_type as usize {
SHT_RELA => { SHT_RELA => {
let relocs = get_ref_slice::<Elf32_Rela>( let Some(relocs) = get_ref_slice::<Elf32_Rela>(
data, data,
$shdr.sh_offset as usize, $shdr.sh_offset as usize,
$shdr.sh_size as usize / mem::size_of::<Elf32_Rela>(), $shdr.sh_size as usize / mem::size_of::<Elf32_Rela>(),
) ) else {
.map_err(|()| "cannot parse relocations")?; Err("cannot parse relocations")?
};
#[allow(clippy::redundant_closure_call)]
$stmt(relocs) $stmt(relocs)
} }
SHT_REL => { SHT_REL => {
let relocs = get_ref_slice::<Elf32_Rel>( let Some(relocs) = get_ref_slice::<Elf32_Rel>(
data, data,
$shdr.sh_offset as usize, $shdr.sh_offset as usize,
$shdr.sh_size as usize / mem::size_of::<Elf32_Rel>(), $shdr.sh_size as usize / mem::size_of::<Elf32_Rel>(),
) ) else {
.map_err(|()| "cannot parse relocations")?; Err("cannot parse relocations")?
};
#[allow(clippy::redundant_closure_call)]
$stmt(relocs) $stmt(relocs)
} }
_ => unreachable!(), _ => unreachable!(),
@ -774,84 +869,6 @@ impl<'a> Linker<'a> {
}; };
} }
fn allocate_rela_dyn<R: Relocatable>(
linker: &Linker,
relocs: &[R],
) -> Result<(usize, Vec<u32>), Error> {
let mut alloc_size = 0;
let mut rela_dyn_sym_indices = Vec::new();
for reloc in relocs {
if reloc.sym_info() as usize == STN_UNDEF {
continue;
}
let sym: &Elf32_Sym = linker
.symtab
.get(reloc.sym_info() as usize)
.ok_or("symbol out of bounds of symbol table")?;
match (linker.isa, reloc.type_info()) {
// Absolute address relocations
// A runtime relocation is needed to find the loading address
(Isa::CortexA9, R_ARM_ABS32) | (Isa::RiscV32, R_RISCV_32) => {
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
if ELF32_ST_BIND(sym.st_info) == STB_GLOBAL && sym.st_shndx == SHN_UNDEF {
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
// Relative address relocations
// Relay the relocation to the runtime linker only if the symbol is not defined
(Isa::CortexA9, R_ARM_REL32)
| (Isa::CortexA9, R_ARM_PREL31)
| (Isa::CortexA9, R_ARM_TARGET2)
| (Isa::RiscV32, R_RISCV_CALL_PLT)
| (Isa::RiscV32, R_RISCV_PCREL_HI20)
| (Isa::RiscV32, R_RISCV_GOT_HI20)
| (Isa::RiscV32, R_RISCV_32_PCREL)
| (Isa::RiscV32, R_RISCV_SET32)
| (Isa::RiscV32, R_RISCV_ADD32)
| (Isa::RiscV32, R_RISCV_SUB32)
| (Isa::RiscV32, R_RISCV_SET16)
| (Isa::RiscV32, R_RISCV_ADD16)
| (Isa::RiscV32, R_RISCV_SUB16)
| (Isa::RiscV32, R_RISCV_SET8)
| (Isa::RiscV32, R_RISCV_ADD8)
| (Isa::RiscV32, R_RISCV_SUB8)
| (Isa::RiscV32, R_RISCV_SET6)
| (Isa::RiscV32, R_RISCV_SUB6) => {
if ELF32_ST_BIND(sym.st_info) == STB_GLOBAL && sym.st_shndx == SHN_UNDEF {
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
// RISC-V: Lower 12-bits relocations
// If the upper 20-bits relocation cannot be resolved,
// this relocation will be relayed to the runtime linker.
(Isa::RiscV32, R_RISCV_PCREL_LO12_I) => {
// Find the HI20 relocation
let indirect_reloc = relocs
.iter()
.find(|reloc| reloc.offset() == sym.st_value)
.ok_or("malformatted LO12 relocation")?;
let indirect_sym = linker.symtab[indirect_reloc.sym_info() as usize];
if ELF32_ST_BIND(indirect_sym.st_info) == STB_GLOBAL
&& indirect_sym.st_shndx == SHN_UNDEF
{
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
_ => {
println!("Relocation type 0x{:X?} is not supported", reloc.type_info());
unimplemented!()
}
}
}
Ok((alloc_size, rela_dyn_sym_indices))
}
for shdr in shdrs for shdr in shdrs
.iter() .iter()
.filter(|shdr| shdr.sh_type as usize == SHT_REL || shdr.sh_type as usize == SHT_RELA) .filter(|shdr| shdr.sh_type as usize == SHT_REL || shdr.sh_type as usize == SHT_RELA)
@ -879,7 +896,7 @@ impl<'a> Linker<'a> {
} }
// Avoid symbol duplication // Avoid symbol duplication
rela_dyn_sym_indices.sort(); rela_dyn_sym_indices.sort_unstable();
rela_dyn_sym_indices.dedup(); rela_dyn_sym_indices.dedup();
if rela_dyn_size != 0 { if rela_dyn_size != 0 {
@ -1010,7 +1027,9 @@ impl<'a> Linker<'a> {
let mut hash_bucket: Vec<u32> = vec![0; dynsym.len()]; let mut hash_bucket: Vec<u32> = vec![0; dynsym.len()];
let mut hash_chain: Vec<u32> = vec![0; dynsym.len()]; let mut hash_chain: Vec<u32> = vec![0; dynsym.len()];
for (sym_index, (str_start, str_end)) in dynsym_names.iter().enumerate().take(dynsym.len()).skip(1) { for (sym_index, (str_start, str_end)) in
dynsym_names.iter().enumerate().take(dynsym.len()).skip(1)
{
let hash = elf_hash(&dynstr[*str_start..*str_end]); let hash = elf_hash(&dynstr[*str_start..*str_end]);
let mut hash_index = hash as usize % hash_bucket.len(); let mut hash_index = hash as usize % hash_bucket.len();
@ -1062,7 +1081,7 @@ impl<'a> Linker<'a> {
sh_entsize: mem::size_of::<Elf32_Sym>() as Elf32_Word, sh_entsize: mem::size_of::<Elf32_Sym>() as Elf32_Word,
}, },
".dynsym", ".dynsym",
from_struct_vec(dynsym), from_struct_slice(&dynsym),
); );
let hash_elf_index = linker.load_section( let hash_elf_index = linker.load_section(
&Elf32_Shdr { &Elf32_Shdr {
@ -1078,7 +1097,7 @@ impl<'a> Linker<'a> {
sh_entsize: 4, sh_entsize: 4,
}, },
".hash", ".hash",
from_struct_vec(hash), from_struct_slice(&hash),
); );
// Link .rela.dyn header to the .dynsym header // Link .rela.dyn header to the .dynsym header
@ -1177,7 +1196,7 @@ impl<'a> Linker<'a> {
}; };
let dynamic_elf_index = let dynamic_elf_index =
linker.load_section(&dynamic_shdr, ".dynamic", from_struct_vec(dyn_entries)); linker.load_section(&dynamic_shdr, ".dynamic", from_struct_slice(&dyn_entries));
let last_w_sec_elf_index = linker.elf_shdrs.len() - 1; let last_w_sec_elf_index = linker.elf_shdrs.len() - 1;
@ -1253,7 +1272,9 @@ impl<'a> Linker<'a> {
update_dynsym_record!(b"__bss_start", bss_offset, bss_elf_index as Elf32_Section); update_dynsym_record!(b"__bss_start", bss_offset, bss_elf_index as Elf32_Section);
update_dynsym_record!(b"_end", bss_offset, bss_elf_index as Elf32_Section); update_dynsym_record!(b"_end", bss_offset, bss_elf_index as Elf32_Section);
} else { } else {
for (bss_iter_index, &(bss_section_index, section_name)) in bss_index_vec.iter().enumerate() { for (bss_iter_index, &(bss_section_index, section_name)) in
bss_index_vec.iter().enumerate()
{
let shdr = &shdrs[bss_section_index]; let shdr = &shdrs[bss_section_index];
let bss_elf_index = linker.load_section( let bss_elf_index = linker.load_section(
shdr, shdr,
@ -1326,7 +1347,7 @@ impl<'a> Linker<'a> {
// Prepare a STRTAB to hold the names of section headers // Prepare a STRTAB to hold the names of section headers
// Fix the sh_name field of the section headers // Fix the sh_name field of the section headers
let mut shstrtab = Vec::new(); let mut shstrtab = Vec::new();
for shdr_rec in linker.elf_shdrs.iter_mut() { for shdr_rec in &mut linker.elf_shdrs {
let shstrtab_index = shstrtab.len(); let shstrtab_index = shstrtab.len();
shstrtab.extend(shdr_rec.name.as_bytes()); shstrtab.extend(shdr_rec.name.as_bytes());
shstrtab.push(0); shstrtab.push(0);
@ -1367,20 +1388,17 @@ impl<'a> Linker<'a> {
let alignment = (4 - (linker.image.len() % 4)) % 4; let alignment = (4 - (linker.image.len() % 4)) % 4;
let sec_headers_offset = linker.image.len() + alignment; let sec_headers_offset = linker.image.len() + alignment;
linker.image.extend(vec![0; alignment]); linker.image.extend(vec![0; alignment]);
for rec in linker.elf_shdrs.iter() { for rec in &linker.elf_shdrs {
let shdr = rec.shdr; let shdr = rec.shdr;
linker.image.extend(unsafe { linker.image.extend(unsafe {
slice::from_raw_parts( slice::from_raw_parts(ptr::addr_of!(shdr).cast(), mem::size_of::<Elf32_Shdr>())
&shdr as *const Elf32_Shdr as *const u8,
mem::size_of::<Elf32_Shdr>(),
)
}); });
} }
// Update the PHDRs // Update the PHDRs
let phdr_offset = mem::size_of::<Elf32_Ehdr>(); let phdr_offset = mem::size_of::<Elf32_Ehdr>();
unsafe { unsafe {
let phdr_ptr = linker.image.as_mut_ptr().add(phdr_offset) as *mut Elf32_Phdr; let phdr_ptr = linker.image.as_mut_ptr().add(phdr_offset).cast();
let phdr_slice = slice::from_raw_parts_mut(phdr_ptr, 5); let phdr_slice = slice::from_raw_parts_mut(phdr_ptr, 5);
// List of program headers: // List of program headers:
// 1. ELF headers & program headers // 1. ELF headers & program headers
@ -1457,7 +1475,7 @@ impl<'a> Linker<'a> {
} }
// Update the EHDR // Update the EHDR
let ehdr_ptr = linker.image.as_mut_ptr() as *mut Elf32_Ehdr; let ehdr_ptr = linker.image.as_mut_ptr().cast();
unsafe { unsafe {
*ehdr_ptr = Elf32_Ehdr { *ehdr_ptr = Elf32_Ehdr {
e_ident: ehdr.e_ident, e_ident: ehdr.e_ident,

View File

@ -16,7 +16,7 @@ lalrpop-util = "0.20"
log = "0.4" log = "0.4"
unic-emoji-char = "0.9" unic-emoji-char = "0.9"
unic-ucd-ident = "0.9" unic-ucd-ident = "0.9"
unicode_names2 = "1.0" unicode_names2 = "1.2"
phf = { version = "0.11", features = ["macros"] } phf = { version = "0.11", features = ["macros"] }
ahash = "0.8" ahash = "0.8"

View File

@ -1,15 +1,15 @@
use lalrpop_util::ParseError;
use nac3ast::*;
use crate::ast::Ident; use crate::ast::Ident;
use crate::ast::Location; use crate::ast::Location;
use crate::token::Tok;
use crate::error::*; use crate::error::*;
use crate::token::Tok;
use lalrpop_util::ParseError;
use nac3ast::*;
pub fn make_config_comment( pub fn make_config_comment(
com_loc: Location, com_loc: Location,
stmt_loc: Location, stmt_loc: Location,
nac3com_above: Vec<(Ident, Tok)>, nac3com_above: Vec<(Ident, Tok)>,
nac3com_end: Option<Ident> nac3com_end: Option<Ident>,
) -> Result<Vec<Ident>, ParseError<Location, Tok, LexicalError>> { ) -> Result<Vec<Ident>, ParseError<Location, Tok, LexicalError>> {
if com_loc.column() != stmt_loc.column() && !nac3com_above.is_empty() { if com_loc.column() != stmt_loc.column() && !nac3com_above.is_empty() {
return Err(ParseError::User { return Err(ParseError::User {
@ -17,24 +17,25 @@ pub fn make_config_comment(
location: com_loc, location: com_loc,
error: LexicalErrorType::OtherError( error: LexicalErrorType::OtherError(
format!( format!(
"config comment at top must have the same indentation with what it applies (comment at {}, statement at {})", "config comment at top must have the same indentation with what it applies (comment at {com_loc}, statement at {stmt_loc})",
com_loc,
stmt_loc,
) )
) )
} }
}) });
}; };
Ok( Ok(nac3com_above
nac3com_above
.into_iter() .into_iter()
.map(|(com, _)| com) .map(|(com, _)| com)
.chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter())) .chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter()))
.collect() .collect())
)
} }
pub fn handle_small_stmt<U>(stmts: &mut [Stmt<U>], nac3com_above: Vec<(Ident, Tok)>, nac3com_end: Option<Ident>, com_above_loc: Location) -> Result<(), ParseError<Location, Tok, LexicalError>> { pub fn handle_small_stmt<U>(
stmts: &mut [Stmt<U>],
nac3com_above: Vec<(Ident, Tok)>,
nac3com_end: Option<Ident>,
com_above_loc: Location,
) -> Result<(), ParseError<Location, Tok, LexicalError>> {
if com_above_loc.column() != stmts[0].location.column() && !nac3com_above.is_empty() { if com_above_loc.column() != stmts[0].location.column() && !nac3com_above.is_empty() {
return Err(ParseError::User { return Err(ParseError::User {
error: LexicalError { error: LexicalError {
@ -47,17 +48,12 @@ pub fn handle_small_stmt<U>(stmts: &mut [Stmt<U>], nac3com_above: Vec<(Ident, To
) )
) )
} }
}) });
} }
apply_config_comments( apply_config_comments(&mut stmts[0], nac3com_above.into_iter().map(|(com, _)| com).collect());
&mut stmts[0],
nac3com_above
.into_iter()
.map(|(com, _)| com).collect()
);
apply_config_comments( apply_config_comments(
stmts.last_mut().unwrap(), stmts.last_mut().unwrap(),
nac3com_end.map_or_else(Vec::new, |com| vec![com]) nac3com_end.map_or_else(Vec::new, |com| vec![com]),
); );
Ok(()) Ok(())
} }
@ -80,6 +76,8 @@ fn apply_config_comments<U>(stmt: &mut Stmt<U>, comments: Vec<Ident>) {
| StmtKind::Nonlocal { config_comment, .. } | StmtKind::Nonlocal { config_comment, .. }
| StmtKind::Assert { config_comment, .. } => config_comment.extend(comments), | StmtKind::Assert { config_comment, .. } => config_comment.extend(comments),
_ => { unreachable!("only small statements should call this function") } _ => {
unreachable!("only small statements should call this function")
}
} }
} }

View File

@ -37,7 +37,7 @@ impl fmt::Display for LexicalErrorType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
LexicalErrorType::StringError => write!(f, "Got unexpected string"), LexicalErrorType::StringError => write!(f, "Got unexpected string"),
LexicalErrorType::FStringError(error) => write!(f, "Got error in f-string: {}", error), LexicalErrorType::FStringError(error) => write!(f, "Got error in f-string: {error}"),
LexicalErrorType::UnicodeError => write!(f, "Got unexpected unicode"), LexicalErrorType::UnicodeError => write!(f, "Got unexpected unicode"),
LexicalErrorType::NestingError => write!(f, "Got unexpected nesting"), LexicalErrorType::NestingError => write!(f, "Got unexpected nesting"),
LexicalErrorType::IndentationError => { LexicalErrorType::IndentationError => {
@ -59,13 +59,13 @@ impl fmt::Display for LexicalErrorType {
write!(f, "positional argument follows keyword argument") write!(f, "positional argument follows keyword argument")
} }
LexicalErrorType::UnrecognizedToken { tok } => { LexicalErrorType::UnrecognizedToken { tok } => {
write!(f, "Got unexpected token {}", tok) write!(f, "Got unexpected token {tok}")
} }
LexicalErrorType::LineContinuationError => { LexicalErrorType::LineContinuationError => {
write!(f, "unexpected character after line continuation character") write!(f, "unexpected character after line continuation character")
} }
LexicalErrorType::Eof => write!(f, "unexpected EOF while parsing"), LexicalErrorType::Eof => write!(f, "unexpected EOF while parsing"),
LexicalErrorType::OtherError(msg) => write!(f, "{}", msg), LexicalErrorType::OtherError(msg) => write!(f, "{msg}"),
} }
} }
} }
@ -96,7 +96,7 @@ impl fmt::Display for FStringErrorType {
FStringErrorType::UnopenedRbrace => write!(f, "Unopened '}}'"), FStringErrorType::UnopenedRbrace => write!(f, "Unopened '}}'"),
FStringErrorType::ExpectedRbrace => write!(f, "Expected '}}' after conversion flag."), FStringErrorType::ExpectedRbrace => write!(f, "Expected '}}' after conversion flag."),
FStringErrorType::InvalidExpression(error) => { FStringErrorType::InvalidExpression(error) => {
write!(f, "Invalid expression: {}", error) write!(f, "Invalid expression: {error}")
} }
FStringErrorType::InvalidConversionFlag => write!(f, "Invalid conversion flag"), FStringErrorType::InvalidConversionFlag => write!(f, "Invalid conversion flag"),
FStringErrorType::EmptyExpression => write!(f, "Empty expression"), FStringErrorType::EmptyExpression => write!(f, "Empty expression"),
@ -144,36 +144,27 @@ pub enum ParseErrorType {
impl From<LalrpopError<Location, Tok, LexicalError>> for ParseError { impl From<LalrpopError<Location, Tok, LexicalError>> for ParseError {
fn from(err: LalrpopError<Location, Tok, LexicalError>) -> Self { fn from(err: LalrpopError<Location, Tok, LexicalError>) -> Self {
match err { match err {
// TODO: Are there cases where this isn't an EOF? LalrpopError::ExtraToken { token } => {
LalrpopError::InvalidToken { location } => ParseError { ParseError { error: ParseErrorType::ExtraToken(token.1), location: token.0 }
error: ParseErrorType::Eof, }
location, LalrpopError::User { error } => {
}, ParseError { error: ParseErrorType::Lexical(error.error), location: error.location }
LalrpopError::ExtraToken { token } => ParseError { }
error: ParseErrorType::ExtraToken(token.1),
location: token.0,
},
LalrpopError::User { error } => ParseError {
error: ParseErrorType::Lexical(error.error),
location: error.location,
},
LalrpopError::UnrecognizedToken { token, expected } => { LalrpopError::UnrecognizedToken { token, expected } => {
// Hacky, but it's how CPython does it. See PyParser_AddToken, // Hacky, but it's how CPython does it. See PyParser_AddToken,
// in particular "Only one possible expected token" comment. // in particular "Only one possible expected token" comment.
let expected = if expected.len() == 1 { let expected = if expected.len() == 1 { Some(expected[0].clone()) } else { None };
Some(expected[0].clone())
} else {
None
};
ParseError { ParseError {
error: ParseErrorType::UnrecognizedToken(token.1, expected), error: ParseErrorType::UnrecognizedToken(token.1, expected),
location: token.0, location: token.0,
} }
} }
LalrpopError::UnrecognizedEof { location, .. } => ParseError {
error: ParseErrorType::Eof, LalrpopError::UnrecognizedEof { location, .. }
location, // TODO: Are there cases where this isn't an EOF?
}, | LalrpopError::InvalidToken { location } => {
ParseError { error: ParseErrorType::Eof, location }
}
} }
} }
} }
@ -188,7 +179,7 @@ impl fmt::Display for ParseErrorType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
ParseErrorType::Eof => write!(f, "Got unexpected EOF"), ParseErrorType::Eof => write!(f, "Got unexpected EOF"),
ParseErrorType::ExtraToken(ref tok) => write!(f, "Got extraneous token: {:?}", tok), ParseErrorType::ExtraToken(ref tok) => write!(f, "Got extraneous token: {tok:?}"),
ParseErrorType::InvalidToken => write!(f, "Got invalid token"), ParseErrorType::InvalidToken => write!(f, "Got invalid token"),
ParseErrorType::UnrecognizedToken(ref tok, ref expected) => { ParseErrorType::UnrecognizedToken(ref tok, ref expected) => {
if *tok == Tok::Indent { if *tok == Tok::Indent {
@ -196,10 +187,10 @@ impl fmt::Display for ParseErrorType {
} else if expected.as_deref() == Some("Indent") { } else if expected.as_deref() == Some("Indent") {
write!(f, "expected an indented block") write!(f, "expected an indented block")
} else { } else {
write!(f, "Got unexpected token {}", tok) write!(f, "Got unexpected token {tok}")
} }
} }
ParseErrorType::Lexical(ref error) => write!(f, "{}", error), ParseErrorType::Lexical(ref error) => write!(f, "{error}"),
} }
} }
} }
@ -207,6 +198,7 @@ impl fmt::Display for ParseErrorType {
impl Error for ParseErrorType {} impl Error for ParseErrorType {}
impl ParseErrorType { impl ParseErrorType {
#[must_use]
pub fn is_indentation_error(&self) -> bool { pub fn is_indentation_error(&self) -> bool {
match self { match self {
ParseErrorType::Lexical(LexicalErrorType::IndentationError) => true, ParseErrorType::Lexical(LexicalErrorType::IndentationError) => true,
@ -216,11 +208,11 @@ impl ParseErrorType {
_ => false, _ => false,
} }
} }
#[must_use]
pub fn is_tab_error(&self) -> bool { pub fn is_tab_error(&self) -> bool {
matches!( matches!(
self, self,
ParseErrorType::Lexical(LexicalErrorType::TabError) ParseErrorType::Lexical(LexicalErrorType::TabError | LexicalErrorType::TabsAfterSpaces)
| ParseErrorType::Lexical(LexicalErrorType::TabsAfterSpaces)
) )
} }
} }

View File

@ -15,10 +15,7 @@ struct FStringParser<'a> {
impl<'a> FStringParser<'a> { impl<'a> FStringParser<'a> {
fn new(source: &'a str, str_location: Location) -> Self { fn new(source: &'a str, str_location: Location) -> Self {
Self { Self { chars: source.chars().peekable(), str_location }
chars: source.chars().peekable(),
str_location,
}
} }
#[inline] #[inline]
@ -133,10 +130,10 @@ impl<'a> FStringParser<'a> {
) )
} else { } else {
Box::new(self.expr(ExprKind::Constant { Box::new(self.expr(ExprKind::Constant {
value: spec_expression.to_owned().into(), value: spec_expression.clone().into(),
kind: None, kind: None,
})) }))
}) });
} }
'(' | '{' | '[' => { '(' | '{' | '[' => {
expression.push(ch); expression.push(ch);
@ -251,17 +248,11 @@ impl<'a> FStringParser<'a> {
} }
if !content.is_empty() { if !content.is_empty() {
values.push(self.expr(ExprKind::Constant { values.push(self.expr(ExprKind::Constant { value: content.into(), kind: None }));
value: content.into(),
kind: None,
}))
} }
let s = match values.len() { let s = match values.len() {
0 => self.expr(ExprKind::Constant { 0 => self.expr(ExprKind::Constant { value: String::new().into(), kind: None }),
value: String::new().into(),
kind: None,
}),
1 => values.into_iter().next().unwrap(), 1 => values.into_iter().next().unwrap(),
_ => self.expr(ExprKind::JoinedStr { values }), _ => self.expr(ExprKind::JoinedStr { values }),
}; };
@ -270,16 +261,14 @@ impl<'a> FStringParser<'a> {
} }
fn parse_fstring_expr(source: &str) -> Result<Expr, ParseError> { fn parse_fstring_expr(source: &str) -> Result<Expr, ParseError> {
let fstring_body = format!("({})", source); let fstring_body = format!("({source})");
parse_expression(&fstring_body) parse_expression(&fstring_body)
} }
/// Parse an fstring from a string, located at a certain position in the sourcecode. /// Parse an fstring from a string, located at a certain position in the sourcecode.
/// In case of errors, we will get the location and the error returned. /// In case of errors, we will get the location and the error returned.
pub fn parse_located_fstring(source: &str, location: Location) -> Result<Expr, FStringError> { pub fn parse_located_fstring(source: &str, location: Location) -> Result<Expr, FStringError> {
FStringParser::new(source, location) FStringParser::new(source, location).parse().map_err(|error| FStringError { error, location })
.parse()
.map_err(|error| FStringError { error, location })
} }
#[cfg(test)] #[cfg(test)]
@ -293,7 +282,7 @@ mod tests {
#[test] #[test]
fn test_parse_fstring() { fn test_parse_fstring() {
let source = "{a}{ b }{{foo}}"; let source = "{a}{ b }{{foo}}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -301,7 +290,7 @@ mod tests {
#[test] #[test]
fn test_parse_fstring_nested_spec() { fn test_parse_fstring_nested_spec() {
let source = "{foo:{spec}}"; let source = "{foo:{spec}}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -309,7 +298,7 @@ mod tests {
#[test] #[test]
fn test_parse_fstring_not_nested_spec() { fn test_parse_fstring_not_nested_spec() {
let source = "{foo:spec}"; let source = "{foo:spec}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -322,7 +311,7 @@ mod tests {
#[test] #[test]
fn test_fstring_parse_selfdocumenting_base() { fn test_fstring_parse_selfdocumenting_base() {
let src = "{user=}"; let src = "{user=}";
let parse_ast = parse_fstring(&src).unwrap(); let parse_ast = parse_fstring(src).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -330,7 +319,7 @@ mod tests {
#[test] #[test]
fn test_fstring_parse_selfdocumenting_base_more() { fn test_fstring_parse_selfdocumenting_base_more() {
let src = "mix {user=} with text and {second=}"; let src = "mix {user=} with text and {second=}";
let parse_ast = parse_fstring(&src).unwrap(); let parse_ast = parse_fstring(src).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -338,7 +327,7 @@ mod tests {
#[test] #[test]
fn test_fstring_parse_selfdocumenting_format() { fn test_fstring_parse_selfdocumenting_format() {
let src = "{user=:>10}"; let src = "{user=:>10}";
let parse_ast = parse_fstring(&src).unwrap(); let parse_ast = parse_fstring(src).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -371,35 +360,35 @@ mod tests {
#[test] #[test]
fn test_parse_fstring_not_equals() { fn test_parse_fstring_not_equals() {
let source = "{1 != 2}"; let source = "{1 != 2}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_fstring_equals() { fn test_parse_fstring_equals() {
let source = "{42 == 42}"; let source = "{42 == 42}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_fstring_selfdoc_prec_space() { fn test_parse_fstring_selfdoc_prec_space() {
let source = "{x =}"; let source = "{x =}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_fstring_selfdoc_trailing_space() { fn test_parse_fstring_selfdoc_trailing_space() {
let source = "{x= }"; let source = "{x= }";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_fstring_yield_expr() { fn test_parse_fstring_yield_expr() {
let source = "{yield}"; let source = "{yield}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
} }

View File

@ -54,8 +54,7 @@ pub fn parse_args(func_args: Vec<FunctionArgument>) -> Result<ArgumentList, Lexi
let mut keyword_names = HashSet::with_capacity_and_hasher(func_args.len(), RandomState::new()); let mut keyword_names = HashSet::with_capacity_and_hasher(func_args.len(), RandomState::new());
for (name, value) in func_args { for (name, value) in func_args {
match name { if let Some((location, name)) = name {
Some((location, name)) => {
if let Some(keyword_name) = &name { if let Some(keyword_name) = &name {
if keyword_names.contains(keyword_name) { if keyword_names.contains(keyword_name) {
return Err(LexicalError { return Err(LexicalError {
@ -69,13 +68,9 @@ pub fn parse_args(func_args: Vec<FunctionArgument>) -> Result<ArgumentList, Lexi
keywords.push(ast::Keyword::new( keywords.push(ast::Keyword::new(
location, location,
ast::KeywordData { ast::KeywordData { arg: name.map(String::into), value: Box::new(value) },
arg: name.map(|name| name.into()),
value: Box::new(value),
},
)); ));
} } else {
None => {
// Allow starred args after keyword arguments. // Allow starred args after keyword arguments.
if !keywords.is_empty() && !is_starred(&value) { if !keywords.is_empty() && !is_starred(&value) {
return Err(LexicalError { return Err(LexicalError {
@ -87,7 +82,6 @@ pub fn parse_args(func_args: Vec<FunctionArgument>) -> Result<ArgumentList, Lexi
args.push(value); args.push(value);
} }
} }
}
Ok(ArgumentList { args, keywords }) Ok(ArgumentList { args, keywords })
} }

View File

@ -3,12 +3,12 @@
//! This means source code is translated into separate tokens. //! This means source code is translated into separate tokens.
pub use super::token::Tok; pub use super::token::Tok;
use crate::ast::{Location, FileName}; use crate::ast::{FileName, Location};
use crate::error::{LexicalError, LexicalErrorType}; use crate::error::{LexicalError, LexicalErrorType};
use std::char; use std::char;
use std::cmp::Ordering; use std::cmp::Ordering;
use std::str::FromStr;
use std::num::IntErrorKind; use std::num::IntErrorKind;
use std::str::FromStr;
use unic_emoji_char::is_emoji_presentation; use unic_emoji_char::is_emoji_presentation;
use unic_ucd_ident::{is_xid_continue, is_xid_start}; use unic_ucd_ident::{is_xid_continue, is_xid_start};
@ -32,20 +32,14 @@ impl IndentationLevel {
if self.spaces <= other.spaces { if self.spaces <= other.spaces {
Ok(Ordering::Less) Ok(Ordering::Less)
} else { } else {
Err(LexicalError { Err(LexicalError { location, error: LexicalErrorType::TabError })
location,
error: LexicalErrorType::TabError,
})
} }
} }
Ordering::Greater => { Ordering::Greater => {
if self.spaces >= other.spaces { if self.spaces >= other.spaces {
Ok(Ordering::Greater) Ok(Ordering::Greater)
} else { } else {
Err(LexicalError { Err(LexicalError { location, error: LexicalErrorType::TabError })
location,
error: LexicalErrorType::TabError,
})
} }
} }
Ordering::Equal => Ok(self.spaces.cmp(&other.spaces)), Ordering::Equal => Ok(self.spaces.cmp(&other.spaces)),
@ -63,7 +57,7 @@ pub struct Lexer<T: Iterator<Item = char>> {
chr1: Option<char>, chr1: Option<char>,
chr2: Option<char>, chr2: Option<char>,
location: Location, location: Location,
config_comment_prefix: Option<&'static str> config_comment_prefix: Option<&'static str>,
} }
pub static KEYWORDS: phf::Map<&'static str, Tok> = phf::phf_map! { pub static KEYWORDS: phf::Map<&'static str, Tok> = phf::phf_map! {
@ -136,11 +130,7 @@ where
T: Iterator<Item = char>, T: Iterator<Item = char>,
{ {
pub fn new(source: T) -> Self { pub fn new(source: T) -> Self {
let mut nlh = NewlineHandler { let mut nlh = NewlineHandler { source, chr0: None, chr1: None };
source,
chr0: None,
chr1: None,
};
nlh.shift(); nlh.shift();
nlh.shift(); nlh.shift();
nlh nlh
@ -169,7 +159,7 @@ where
self.shift(); self.shift();
} else { } else {
// Transform MAC EOL into \n // Transform MAC EOL into \n
self.chr0 = Some('\n') self.chr0 = Some('\n');
} }
} else { } else {
break; break;
@ -189,13 +179,13 @@ where
chars: input, chars: input,
at_begin_of_line: true, at_begin_of_line: true,
nesting: 0, nesting: 0,
indentation_stack: vec![Default::default()], indentation_stack: vec![IndentationLevel::default()],
pending: Vec::new(), pending: Vec::new(),
chr0: None, chr0: None,
location: start, location: start,
chr1: None, chr1: None,
chr2: None, chr2: None,
config_comment_prefix: Some(" nac3:") config_comment_prefix: Some(" nac3:"),
}; };
lxr.next_char(); lxr.next_char();
lxr.next_char(); lxr.next_char();
@ -217,11 +207,9 @@ where
let mut saw_f = false; let mut saw_f = false;
loop { loop {
// Detect r"", f"", b"" and u"" // Detect r"", f"", b"" and u""
if !(saw_b || saw_u || saw_f) && matches!(self.chr0, Some('b') | Some('B')) { if !(saw_b || saw_u || saw_f) && matches!(self.chr0, Some('b' | 'B')) {
saw_b = true; saw_b = true;
} else if !(saw_b || saw_r || saw_u || saw_f) } else if !(saw_b || saw_r || saw_u || saw_f) && matches!(self.chr0, Some('u' | 'U')) {
&& matches!(self.chr0, Some('u') | Some('U'))
{
saw_u = true; saw_u = true;
} else if !(saw_r || saw_u) && (self.chr0 == Some('r') || self.chr0 == Some('R')) { } else if !(saw_r || saw_u) && (self.chr0 == Some('r') || self.chr0 == Some('R')) {
saw_r = true; saw_r = true;
@ -287,15 +275,15 @@ where
let end_pos = self.get_pos(); let end_pos = self.get_pos();
let value = match i128::from_str_radix(&value_text, radix) { let value = match i128::from_str_radix(&value_text, radix) {
Ok(value) => value, Ok(value) => value,
Err(e) => { Err(e) => match e.kind() {
match e.kind() {
IntErrorKind::PosOverflow | IntErrorKind::NegOverflow => i128::MAX, IntErrorKind::PosOverflow | IntErrorKind::NegOverflow => i128::MAX,
_ => return Err(LexicalError { _ => {
error: LexicalErrorType::OtherError(format!("{:?}", e)), return Err(LexicalError {
error: LexicalErrorType::OtherError(format!("{e:?}")),
location: start_pos, location: start_pos,
}), })
}
} }
},
}; };
Ok((start_pos, Tok::Int { value }, end_pos)) Ok((start_pos, Tok::Int { value }, end_pos))
} }
@ -338,14 +326,7 @@ where
if self.chr0 == Some('j') || self.chr0 == Some('J') { if self.chr0 == Some('j') || self.chr0 == Some('J') {
self.next_char(); self.next_char();
let end_pos = self.get_pos(); let end_pos = self.get_pos();
Ok(( Ok((start_pos, Tok::Complex { real: 0.0, imag: value }, end_pos))
start_pos,
Tok::Complex {
real: 0.0,
imag: value,
},
end_pos,
))
} else { } else {
let end_pos = self.get_pos(); let end_pos = self.get_pos();
Ok((start_pos, Tok::Float { value }, end_pos)) Ok((start_pos, Tok::Float { value }, end_pos))
@ -364,7 +345,7 @@ where
let value = value_text.parse::<i128>().ok(); let value = value_text.parse::<i128>().ok();
let nonzero = match value { let nonzero = match value {
Some(value) => value != 0i128, Some(value) => value != 0i128,
None => true None => true,
}; };
if start_is_zero && nonzero { if start_is_zero && nonzero {
return Err(LexicalError { return Err(LexicalError {
@ -379,7 +360,7 @@ where
/// Consume a sequence of numbers with the given radix, /// Consume a sequence of numbers with the given radix,
/// the digits can be decorated with underscores /// the digits can be decorated with underscores
/// like this: '1_2_3_4' == '1234' /// like this: `'1_2_3_4'` == `'1234'`
fn radix_run(&mut self, radix: u32) -> String { fn radix_run(&mut self, radix: u32) -> String {
let mut value_text = String::new(); let mut value_text = String::new();
@ -412,7 +393,7 @@ where
2 => matches!(c, Some('0'..='1')), 2 => matches!(c, Some('0'..='1')),
8 => matches!(c, Some('0'..='7')), 8 => matches!(c, Some('0'..='7')),
10 => matches!(c, Some('0'..='9')), 10 => matches!(c, Some('0'..='9')),
16 => matches!(c, Some('0'..='9') | Some('a'..='f') | Some('A'..='F')), 16 => matches!(c, Some('0'..='9' | 'a'..='f' | 'A'..='F')),
other => unimplemented!("Radix not implemented: {}", other), other => unimplemented!("Radix not implemented: {}", other),
} }
} }
@ -420,8 +401,8 @@ where
/// Test if we face '[eE][-+]?[0-9]+' /// Test if we face '[eE][-+]?[0-9]+'
fn at_exponent(&self) -> bool { fn at_exponent(&self) -> bool {
match self.chr0 { match self.chr0 {
Some('e') | Some('E') => match self.chr1 { Some('e' | 'E') => match self.chr1 {
Some('+') | Some('-') => matches!(self.chr2, Some('0'..='9')), Some('+' | '-') => matches!(self.chr2, Some('0'..='9')),
Some('0'..='9') => true, Some('0'..='9') => true,
_ => false, _ => false,
}, },
@ -433,19 +414,17 @@ where
fn lex_comment(&mut self) -> Option<Spanned> { fn lex_comment(&mut self) -> Option<Spanned> {
self.next_char(); self.next_char();
// if possibly nac3 pseudocomment, special handling for `# nac3:` // if possibly nac3 pseudocomment, special handling for `# nac3:`
let (mut prefix, mut is_comment) = self let (mut prefix, mut is_comment) =
.config_comment_prefix self.config_comment_prefix.map_or_else(|| ("".chars(), false), |v| (v.chars(), true));
.map_or_else(|| ("".chars(), false), |v| (v.chars(), true));
// for the correct location of config comment // for the correct location of config comment
let mut start_loc = self.location; let mut start_loc = self.location;
start_loc.go_left(); start_loc.go_left();
loop { loop {
match self.chr0 { match self.chr0 {
Some('\n') => return None, Some('\n') | None => return None,
None => return None,
Some(c) => { Some(c) => {
if let (true, Some(p)) = (is_comment, prefix.next()) { if let (true, Some(p)) = (is_comment, prefix.next()) {
is_comment = is_comment && c == p is_comment = is_comment && c == p;
} else { } else {
// done checking prefix, if is comment then return the spanned // done checking prefix, if is comment then return the spanned
if is_comment { if is_comment {
@ -460,22 +439,20 @@ where
return Some(( return Some((
start_loc, start_loc,
Tok::ConfigComment { content: content.trim().into() }, Tok::ConfigComment { content: content.trim().into() },
self.location self.location,
)); ));
} }
} }
} }
} }
self.next_char(); self.next_char();
}; }
} }
fn unicode_literal(&mut self, literal_number: usize) -> Result<char, LexicalError> { fn unicode_literal(&mut self, literal_number: usize) -> Result<char, LexicalError> {
let mut p: u32 = 0u32; let mut p: u32 = 0u32;
let unicode_error = LexicalError { let unicode_error =
error: LexicalErrorType::UnicodeError, LexicalError { error: LexicalErrorType::UnicodeError, location: self.get_pos() };
location: self.get_pos(),
};
for i in 1..=literal_number { for i in 1..=literal_number {
match self.next_char() { match self.next_char() {
Some(c) => match c.to_digit(16) { Some(c) => match c.to_digit(16) {
@ -496,7 +473,7 @@ where
octet_content.push(first); octet_content.push(first);
while octet_content.len() < 3 { while octet_content.len() < 3 {
if let Some('0'..='7') = self.chr0 { if let Some('0'..='7') = self.chr0 {
octet_content.push(self.next_char().unwrap()) octet_content.push(self.next_char().unwrap());
} else { } else {
break; break;
} }
@ -530,10 +507,8 @@ where
} }
} }
} }
unicode_names2::character(&name).ok_or(LexicalError { unicode_names2::character(&name)
error: LexicalErrorType::UnicodeError, .ok_or(LexicalError { error: LexicalErrorType::UnicodeError, location: start_pos })
location: start_pos,
})
} }
fn lex_string( fn lex_string(
@ -566,7 +541,7 @@ where
} else if is_raw { } else if is_raw {
string_content.push('\\'); string_content.push('\\');
if let Some(c) = self.next_char() { if let Some(c) = self.next_char() {
string_content.push(c) string_content.push(c);
} else { } else {
return Err(LexicalError { return Err(LexicalError {
error: LexicalErrorType::StringError, error: LexicalErrorType::StringError,
@ -599,7 +574,7 @@ where
Some('u') if !is_bytes => string_content.push(self.unicode_literal(4)?), Some('u') if !is_bytes => string_content.push(self.unicode_literal(4)?),
Some('U') if !is_bytes => string_content.push(self.unicode_literal(8)?), Some('U') if !is_bytes => string_content.push(self.unicode_literal(8)?),
Some('N') if !is_bytes => { Some('N') if !is_bytes => {
string_content.push(self.parse_unicode_name()?) string_content.push(self.parse_unicode_name()?);
} }
Some(c) => { Some(c) => {
string_content.push('\\'); string_content.push('\\');
@ -650,20 +625,15 @@ where
let end_pos = self.get_pos(); let end_pos = self.get_pos();
let tok = if is_bytes { let tok = if is_bytes {
Tok::Bytes { Tok::Bytes { value: string_content.chars().map(|c| c as u8).collect() }
value: string_content.chars().map(|c| c as u8).collect(),
}
} else { } else {
Tok::String { Tok::String { value: string_content, is_fstring }
value: string_content,
is_fstring,
}
}; };
Ok((start_pos, tok, end_pos)) Ok((start_pos, tok, end_pos))
} }
fn is_identifier_start(&self, c: char) -> bool { fn is_identifier_start(c: char) -> bool {
match c { match c {
'_' | 'a'..='z' | 'A'..='Z' => true, '_' | 'a'..='z' | 'A'..='Z' => true,
'+' | '-' | '*' | '/' | '=' | ' ' | '<' | '>' => false, '+' | '-' | '*' | '/' | '=' | ' ' | '<' | '>' => false,
@ -835,18 +805,14 @@ where
// Check if we have some character: // Check if we have some character:
if let Some(c) = self.chr0 { if let Some(c) = self.chr0 {
// First check identifier: // First check identifier:
if self.is_identifier_start(c) { if Self::is_identifier_start(c) {
let identifier = self.lex_identifier()?; let identifier = self.lex_identifier()?;
self.emit(identifier); self.emit(identifier);
} else if is_emoji_presentation(c) { } else if is_emoji_presentation(c) {
let tok_start = self.get_pos(); let tok_start = self.get_pos();
self.next_char(); self.next_char();
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit(( self.emit((tok_start, Tok::Name { name: c.to_string().into() }, tok_end));
tok_start,
Tok::Name { name: c.to_string().into() },
tok_end,
));
} else { } else {
self.consume_character(c)?; self.consume_character(c)?;
} }
@ -899,18 +865,15 @@ where
'=' => { '=' => {
let tok_start = self.get_pos(); let tok_start = self.get_pos();
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => {
self.next_char(); self.next_char();
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit((tok_start, Tok::EqEqual, tok_end)); self.emit((tok_start, Tok::EqEqual, tok_end));
} } else {
_ => {
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit((tok_start, Tok::Equal, tok_end)); self.emit((tok_start, Tok::Equal, tok_end));
} }
} }
}
'+' => { '+' => {
let tok_start = self.get_pos(); let tok_start = self.get_pos();
self.next_char(); self.next_char();
@ -934,18 +897,15 @@ where
} }
Some('*') => { Some('*') => {
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => {
self.next_char(); self.next_char();
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit((tok_start, Tok::DoubleStarEqual, tok_end)); self.emit((tok_start, Tok::DoubleStarEqual, tok_end));
} } else {
_ => {
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit((tok_start, Tok::DoubleStar, tok_end)); self.emit((tok_start, Tok::DoubleStar, tok_end));
} }
} }
}
_ => { _ => {
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit((tok_start, Tok::Star, tok_end)); self.emit((tok_start, Tok::Star, tok_end));
@ -963,18 +923,15 @@ where
} }
Some('/') => { Some('/') => {
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => {
self.next_char(); self.next_char();
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit((tok_start, Tok::DoubleSlashEqual, tok_end)); self.emit((tok_start, Tok::DoubleSlashEqual, tok_end));
} } else {
_ => {
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit((tok_start, Tok::DoubleSlash, tok_end)); self.emit((tok_start, Tok::DoubleSlash, tok_end));
} }
} }
}
_ => { _ => {
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit((tok_start, Tok::Slash, tok_end)); self.emit((tok_start, Tok::Slash, tok_end));
@ -1141,18 +1098,15 @@ where
match self.chr0 { match self.chr0 {
Some('<') => { Some('<') => {
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => {
self.next_char(); self.next_char();
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit((tok_start, Tok::LeftShiftEqual, tok_end)); self.emit((tok_start, Tok::LeftShiftEqual, tok_end));
} } else {
_ => {
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit((tok_start, Tok::LeftShift, tok_end)); self.emit((tok_start, Tok::LeftShift, tok_end));
} }
} }
}
Some('=') => { Some('=') => {
self.next_char(); self.next_char();
let tok_end = self.get_pos(); let tok_end = self.get_pos();
@ -1170,18 +1124,15 @@ where
match self.chr0 { match self.chr0 {
Some('>') => { Some('>') => {
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => {
self.next_char(); self.next_char();
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit((tok_start, Tok::RightShiftEqual, tok_end)); self.emit((tok_start, Tok::RightShiftEqual, tok_end));
} } else {
_ => {
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit((tok_start, Tok::RightShift, tok_end)); self.emit((tok_start, Tok::RightShift, tok_end));
} }
} }
}
Some('=') => { Some('=') => {
self.next_char(); self.next_char();
let tok_end = self.get_pos(); let tok_end = self.get_pos();
@ -1333,13 +1284,14 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{make_tokenizer, NewlineHandler, Tok}; use super::{make_tokenizer, NewlineHandler, Tok};
use nac3ast::FileName;
const WINDOWS_EOL: &str = "\r\n"; const WINDOWS_EOL: &str = "\r\n";
const MAC_EOL: &str = "\r"; const MAC_EOL: &str = "\r";
const UNIX_EOL: &str = "\n"; const UNIX_EOL: &str = "\n";
pub fn lex_source(source: &str) -> Vec<Tok> { pub fn lex_source(source: &str) -> Vec<Tok> {
let lexer = make_tokenizer(source, Default::default()); let lexer = make_tokenizer(source, FileName::default());
lexer.map(|x| x.unwrap().1).collect() lexer.map(|x| x.unwrap().1).collect()
} }
@ -1419,7 +1371,7 @@ class Foo(A, B):
Dedent, Dedent,
Dedent Dedent
] ]
) );
} }
#[test] #[test]
@ -1439,14 +1391,8 @@ class Foo(A, B):
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![
Tok::String { Tok::String { value: "\\\\".to_owned(), is_fstring: false },
value: "\\\\".to_owned(), Tok::String { value: "\\".to_owned(), is_fstring: false },
is_fstring: false,
},
Tok::String {
value: "\\".to_owned(),
is_fstring: false,
},
Tok::Newline, Tok::Newline,
] ]
); );
@ -1459,27 +1405,13 @@ class Foo(A, B):
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![
Tok::Int { Tok::Int { value: 47i128 },
value: 47i128, Tok::Int { value: 13i128 },
}, Tok::Int { value: 0i128 },
Tok::Int { Tok::Int { value: 123i128 },
value: 13i128,
},
Tok::Int {
value: 0i128,
},
Tok::Int {
value: 123i128,
},
Tok::Float { value: 0.2 }, Tok::Float { value: 0.2 },
Tok::Complex { Tok::Complex { real: 0.0, imag: 2.0 },
real: 0.0, Tok::Complex { real: 0.0, imag: 2.2 },
imag: 2.0,
},
Tok::Complex {
real: 0.0,
imag: 2.2,
},
Tok::Newline, Tok::Newline,
] ]
); );
@ -1539,21 +1471,13 @@ class Foo(A, B):
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![
Tok::Name { Tok::Name { name: String::from("avariable").into() },
name: String::from("avariable").into(),
},
Tok::Equal, Tok::Equal,
Tok::Int { Tok::Int { value: 99i128 },
value: 99i128
},
Tok::Plus, Tok::Plus,
Tok::Int { Tok::Int { value: 2i128 },
value: 2i128
},
Tok::Minus, Tok::Minus,
Tok::Int { Tok::Int { value: 0i128 },
value: 0i128
},
Tok::Newline, Tok::Newline,
] ]
); );
@ -1740,42 +1664,15 @@ class Foo(A, B):
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![
Tok::String { Tok::String { value: String::from("double"), is_fstring: false },
value: String::from("double"), Tok::String { value: String::from("single"), is_fstring: false },
is_fstring: false, Tok::String { value: String::from("can't"), is_fstring: false },
}, Tok::String { value: String::from("\\\""), is_fstring: false },
Tok::String { Tok::String { value: String::from("\t\r\n"), is_fstring: false },
value: String::from("single"), Tok::String { value: String::from("\\g"), is_fstring: false },
is_fstring: false, Tok::String { value: String::from("raw\\'"), is_fstring: false },
}, Tok::String { value: String::from("Đ"), is_fstring: false },
Tok::String { Tok::String { value: String::from("\u{80}\u{0}a"), is_fstring: false },
value: String::from("can't"),
is_fstring: false,
},
Tok::String {
value: String::from("\\\""),
is_fstring: false,
},
Tok::String {
value: String::from("\t\r\n"),
is_fstring: false,
},
Tok::String {
value: String::from("\\g"),
is_fstring: false,
},
Tok::String {
value: String::from("raw\\'"),
is_fstring: false,
},
Tok::String {
value: String::from("Đ"),
is_fstring: false,
},
Tok::String {
value: String::from("\u{80}\u{0}a"),
is_fstring: false,
},
Tok::Newline, Tok::Newline,
] ]
); );
@ -1830,7 +1727,7 @@ class Foo(A, B):
#[test] #[test]
fn test_escape_char_in_byte_literal() { fn test_escape_char_in_byte_literal() {
// backslash does not escape // backslash does not escape
let source = r##"b"omkmok\Xaa""##; let source = r#"b"omkmok\Xaa""#;
let tokens = lex_source(source); let tokens = lex_source(source);
let res = vec![111, 109, 107, 109, 111, 107, 92, 88, 97, 97]; let res = vec![111, 109, 107, 109, 111, 107, 92, 88, 97, 97];
assert_eq!(tokens, vec![Tok::Bytes { value: res }, Tok::Newline]); assert_eq!(tokens, vec![Tok::Bytes { value: res }, Tok::Newline]);
@ -1840,41 +1737,17 @@ class Foo(A, B):
fn test_raw_byte_literal() { fn test_raw_byte_literal() {
let source = r"rb'\x1z'"; let source = r"rb'\x1z'";
let tokens = lex_source(source); let tokens = lex_source(source);
assert_eq!( assert_eq!(tokens, vec![Tok::Bytes { value: b"\\x1z".to_vec() }, Tok::Newline]);
tokens,
vec![
Tok::Bytes {
value: b"\\x1z".to_vec()
},
Tok::Newline
]
);
let source = r"rb'\\'"; let source = r"rb'\\'";
let tokens = lex_source(source); let tokens = lex_source(source);
assert_eq!( assert_eq!(tokens, vec![Tok::Bytes { value: b"\\\\".to_vec() }, Tok::Newline]);
tokens,
vec![
Tok::Bytes {
value: b"\\\\".to_vec()
},
Tok::Newline
]
)
} }
#[test] #[test]
fn test_escape_octet() { fn test_escape_octet() {
let source = r##"b'\43a\4\1234'"##; let source = r"b'\43a\4\1234'";
let tokens = lex_source(source); let tokens = lex_source(source);
assert_eq!( assert_eq!(tokens, vec![Tok::Bytes { value: b"#a\x04S4".to_vec() }, Tok::Newline]);
tokens,
vec![
Tok::Bytes {
value: b"#a\x04S4".to_vec()
},
Tok::Newline
]
)
} }
#[test] #[test]
@ -1883,13 +1756,7 @@ class Foo(A, B):
let tokens = lex_source(source); let tokens = lex_source(source);
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![Tok::String { value: "\u{2002}".to_owned(), is_fstring: false }, Tok::Newline]
Tok::String { );
value: "\u{2002}".to_owned(),
is_fstring: false,
},
Tok::Newline
]
)
} }
} }

View File

@ -15,6 +15,24 @@
//! //!
//! ``` //! ```
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
clippy::enum_glob_use,
clippy::fn_params_excessive_bools,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::too_many_lines,
clippy::wildcard_imports
)]
#[macro_use] #[macro_use]
extern crate log; extern crate log;
use lalrpop_util::lalrpop_mod; use lalrpop_util::lalrpop_mod;
@ -27,9 +45,16 @@ pub mod lexer;
pub mod mode; pub mod mode;
pub mod parser; pub mod parser;
lalrpop_mod!( lalrpop_mod!(
#[allow(clippy::all)] #[allow(
#[allow(unused)] future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
unused,
clippy::all,
clippy::pedantic
)]
python python
); );
pub mod token;
pub mod config_comment_helper; pub mod config_comment_helper;
pub mod token;

View File

@ -5,6 +5,7 @@
//! parse a whole program, a single statement, or a single //! parse a whole program, a single statement, or a single
//! expression. //! expression.
use nac3ast::Location;
use std::iter; use std::iter;
use crate::ast::{self, FileName}; use crate::ast::{self, FileName};
@ -63,7 +64,7 @@ pub fn parse_program(source: &str, file: FileName) -> Result<ast::Suite, ParseEr
/// ///
/// ``` /// ```
pub fn parse_expression(source: &str) -> Result<ast::Expr, ParseError> { pub fn parse_expression(source: &str) -> Result<ast::Expr, ParseError> {
parse(source, Mode::Expression, Default::default()).map(|top| match top { parse(source, Mode::Expression, FileName::default()).map(|top| match top {
ast::Mod::Expression { body } => *body, ast::Mod::Expression { body } => *body,
_ => unreachable!(), _ => unreachable!(),
}) })
@ -72,12 +73,10 @@ pub fn parse_expression(source: &str) -> Result<ast::Expr, ParseError> {
// Parse a given source code // Parse a given source code
pub fn parse(source: &str, mode: Mode, file: FileName) -> Result<ast::Mod, ParseError> { pub fn parse(source: &str, mode: Mode, file: FileName) -> Result<ast::Mod, ParseError> {
let lxr = lexer::make_tokenizer(source, file); let lxr = lexer::make_tokenizer(source, file);
let marker_token = (Default::default(), mode.to_marker(), Default::default()); let marker_token = (Location::default(), mode.to_marker(), Location::default());
let tokenizer = iter::once(Ok(marker_token)).chain(lxr); let tokenizer = iter::once(Ok(marker_token)).chain(lxr);
python::TopParser::new() python::TopParser::new().parse(tokenizer).map_err(ParseError::from)
.parse(tokenizer)
.map_err(ParseError::from)
} }
#[cfg(test)] #[cfg(test)]
@ -86,42 +85,42 @@ mod tests {
#[test] #[test]
fn test_parse_empty() { fn test_parse_empty() {
let parse_ast = parse_program("", Default::default()).unwrap(); let parse_ast = parse_program("", FileName::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_print_hello() { fn test_parse_print_hello() {
let source = String::from("print('Hello world')"); let source = String::from("print('Hello world')");
let parse_ast = parse_program(&source, Default::default()).unwrap(); let parse_ast = parse_program(&source, FileName::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_print_2() { fn test_parse_print_2() {
let source = String::from("print('Hello world', 2)"); let source = String::from("print('Hello world', 2)");
let parse_ast = parse_program(&source, Default::default()).unwrap(); let parse_ast = parse_program(&source, FileName::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_kwargs() { fn test_parse_kwargs() {
let source = String::from("my_func('positional', keyword=2)"); let source = String::from("my_func('positional', keyword=2)");
let parse_ast = parse_program(&source, Default::default()).unwrap(); let parse_ast = parse_program(&source, FileName::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_if_elif_else() { fn test_parse_if_elif_else() {
let source = String::from("if 1: 10\nelif 2: 20\nelse: 30"); let source = String::from("if 1: 10\nelif 2: 20\nelse: 30");
let parse_ast = parse_program(&source, Default::default()).unwrap(); let parse_ast = parse_program(&source, FileName::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_lambda() { fn test_parse_lambda() {
let source = "lambda x, y: x * y"; // lambda(x, y): x * y"; let source = "lambda x, y: x * y"; // lambda(x, y): x * y";
let parse_ast = parse_program(source, Default::default()).unwrap(); let parse_ast = parse_program(source, FileName::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -129,7 +128,7 @@ mod tests {
fn test_parse_tuples() { fn test_parse_tuples() {
let source = "a, b = 4, 5"; let source = "a, b = 4, 5";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap()); insta::assert_debug_snapshot!(parse_program(source, FileName::default()).unwrap());
} }
#[test] #[test]
@ -140,7 +139,7 @@ class Foo(A, B):
pass pass
def method_with_default(self, arg='default'): def method_with_default(self, arg='default'):
pass"; pass";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap()); insta::assert_debug_snapshot!(parse_program(source, FileName::default()).unwrap());
} }
#[test] #[test]
@ -183,7 +182,7 @@ while i < 2: # nac3: 4
# nac3: if1 # nac3: if1
if 1: # nac3: if2 if 1: # nac3: if2
3"; 3";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap()); insta::assert_debug_snapshot!(parse_program(source, FileName::default()).unwrap());
} }
#[test] #[test]
@ -196,7 +195,7 @@ while test: # nac3: while3
# nac3: simple assign0 # nac3: simple assign0
a = 3 # nac3: simple assign1 a = 3 # nac3: simple assign1
"; ";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap()); insta::assert_debug_snapshot!(parse_program(source, FileName::default()).unwrap());
} }
#[test] #[test]
@ -215,7 +214,7 @@ if a: # nac3: small2
for i in a: # nac3: for1 for i in a: # nac3: for1
pass pass
"; ";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap()); insta::assert_debug_snapshot!(parse_program(source, FileName::default()).unwrap());
} }
#[test] #[test]
@ -224,6 +223,6 @@ for i in a: # nac3: for1
if a: # nac3: something if a: # nac3: something
a = 3 a = 3
"; ";
assert!(parse_program(source, Default::default()).is_err()); assert!(parse_program(source, FileName::default()).is_err());
} }
} }

View File

@ -1,7 +1,7 @@
//! Different token definitions. //! Different token definitions.
//! Loosely based on token.h from CPython source: //! Loosely based on token.h from CPython source:
use std::fmt::{self, Write};
use crate::ast; use crate::ast;
use std::fmt::{self, Write};
/// Python source code can be tokenized in a sequence of these tokens. /// Python source code can be tokenized in a sequence of these tokens.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
@ -111,15 +111,23 @@ impl fmt::Display for Tok {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use Tok::*; use Tok::*;
match self { match self {
Name { name } => write!(f, "'{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *name)), Name { name } => {
Int { value } => if *value != i128::MAX { write!(f, "'{}'", value) } else { write!(f, "'#OFL#'") }, write!(f, "'{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *name))
Float { value } => write!(f, "'{}'", value), }
Complex { real, imag } => write!(f, "{}j{}", real, imag), Int { value } => {
if *value == i128::MAX {
write!(f, "'#OFL#'")
} else {
write!(f, "'{value}'")
}
}
Float { value } => write!(f, "'{value}'"),
Complex { real, imag } => write!(f, "{real}j{imag}"),
String { value, is_fstring } => { String { value, is_fstring } => {
if *is_fstring { if *is_fstring {
write!(f, "f")? write!(f, "f")?;
} }
write!(f, "{:?}", value) write!(f, "{value:?}")
} }
Bytes { value } => { Bytes { value } => {
write!(f, "b\"")?; write!(f, "b\"")?;
@ -129,12 +137,16 @@ impl fmt::Display for Tok {
10 => f.write_str("\\n")?, 10 => f.write_str("\\n")?,
13 => f.write_str("\\r")?, 13 => f.write_str("\\r")?,
32..=126 => f.write_char(*i as char)?, 32..=126 => f.write_char(*i as char)?,
_ => write!(f, "\\x{:02x}", i)?, _ => write!(f, "\\x{i:02x}")?,
} }
} }
f.write_str("\"") f.write_str("\"")
} }
ConfigComment { content } => write!(f, "ConfigComment: '{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *content)), ConfigComment { content } => write!(
f,
"ConfigComment: '{}'",
ast::get_str_from_ref(&ast::get_str_ref_lock(), *content)
),
Newline => f.write_str("Newline"), Newline => f.write_str("Newline"),
Indent => f.write_str("Indent"), Indent => f.write_str("Indent"),
Dedent => f.write_str("Dedent"), Dedent => f.write_str("Dedent"),

View File

@ -10,10 +10,10 @@ nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
[dependencies.clap] [dependencies.clap]
version = "4.4" version = "4.5"
features = ["derive"] features = ["derive"]
[dependencies.inkwell] [dependencies.inkwell]
version = "0.2" version = "0.4"
default-features = false default-features = false
features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"] features = ["llvm14-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]

4
nac3standalone/demo/.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
*.bc
*.ll
*.o
/demo

View File

@ -1,3 +1,4 @@
#include <inttypes.h>
#include <math.h> #include <math.h>
#include <stdbool.h> #include <stdbool.h>
#include <stdint.h> #include <stdint.h>
@ -5,13 +6,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#if __SIZEOF_POINTER__ == 8 #define usize size_t
#define usize uint64_t
#elif __SIZEOF_POINTER__ == 4
#define usize uint32_t
#else
#error "Unsupported platform - Platform is not 32-bit or 64-bit"
#endif
double dbl_nan(void) { double dbl_nan(void) {
return NAN; return NAN;
@ -26,19 +21,19 @@ void output_bool(bool x) {
} }
void output_int32(int32_t x) { void output_int32(int32_t x) {
printf("%d\n", x); printf("%"PRId32"\n", x);
} }
void output_int64(int64_t x) { void output_int64(int64_t x) {
printf("%lld\n", x); printf("%"PRId64"\n", x);
} }
void output_uint32(uint32_t x) { void output_uint32(uint32_t x) {
printf("%u\n", x); printf("%"PRIu32"\n", x);
} }
void output_uint64(uint64_t x) { void output_uint64(uint64_t x) {
printf("%llu\n", x); printf("%"PRIu64"\n", x);
} }
void output_float64(double x) { void output_float64(double x) {
@ -49,6 +44,15 @@ void output_float64(double x) {
} }
} }
void output_range(int32_t range[3]) {
printf("range(");
printf("%d, %d", range[0], range[1]);
if (range[2] != 1) {
printf(", %d", range[2]);
}
puts(")");
}
void output_asciiart(int32_t x) { void output_asciiart(int32_t x) {
static const char *chars = " .,-:;i+hHM$*#@ "; static const char *chars = " .,-:;i+hHM$*#@ ";
if (x < 0) { if (x < 0) {
@ -84,6 +88,10 @@ void output_str(struct cslice *slice) {
for (usize i = 0; i < slice->len; ++i) { for (usize i = 0; i < slice->len; ++i) {
putchar(data[i]); putchar(data[i]);
} }
}
void output_strln(struct cslice *slice) {
output_str(slice);
putchar('\n'); putchar('\n');
} }
@ -94,13 +102,13 @@ uint64_t dbg_stack_address(__attribute__((unused)) struct cslice *slice) {
} }
uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t context) { uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t context) {
printf("__nac3_personality(state: %u, exception_object: %u, context: %u\n", state, exception_object, context); printf("__nac3_personality(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context);
exit(101); exit(101);
__builtin_unreachable(); __builtin_unreachable();
} }
uint32_t __nac3_raise(uint32_t state, uint32_t exception_object, uint32_t context) { uint32_t __nac3_raise(uint32_t state, uint32_t exception_object, uint32_t context) {
printf("__nac3_raise(state: %u, exception_object: %u, context: %u\n", state, exception_object, context); printf("__nac3_raise(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context);
exit(101); exit(101);
__builtin_unreachable(); __builtin_unreachable();
} }

View File

@ -5,11 +5,12 @@ import importlib.util
import importlib.machinery import importlib.machinery
import math import math
import numpy as np import numpy as np
import numpy.typing as npt
import pathlib import pathlib
from numpy import int32, int64, uint32, uint64 from numpy import int32, int64, uint32, uint64
from scipy import special from scipy import special
from typing import TypeVar, Generic from typing import TypeVar, Generic, Literal, Union
T = TypeVar('T') T = TypeVar('T')
class Option(Generic[T]): class Option(Generic[T]):
@ -50,12 +51,46 @@ class _ConstGenericMarker:
def ConstGeneric(name, constraint): def ConstGeneric(name, constraint):
return TypeVar(name, _ConstGenericMarker, constraint) return TypeVar(name, _ConstGenericMarker, constraint)
N = TypeVar("N", bound=np.uint64)
class _NDArrayDummy(Generic[T, N]):
pass
# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic
NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]]
def _bool(x):
if isinstance(x, np.ndarray):
return np.bool_(x)
else:
return bool(x)
def _float(x):
if isinstance(x, np.ndarray):
return np.float_(x)
else:
return float(x)
def round_away_zero(x): def round_away_zero(x):
if isinstance(x, np.ndarray):
return np.vectorize(round_away_zero)(x)
else:
if x >= 0.0: if x >= 0.0:
return math.floor(x + 0.5) return math.floor(x + 0.5)
else: else:
return math.ceil(x - 0.5) return math.ceil(x - 0.5)
def _floor(x):
if isinstance(x, np.ndarray):
return np.vectorize(_floor)(x)
else:
return math.floor(x)
def _ceil(x):
if isinstance(x, np.ndarray):
return np.vectorize(_ceil)(x)
else:
return math.ceil(x)
def patch(module): def patch(module):
def dbl_nan(): def dbl_nan():
return np.nan return np.nan
@ -72,6 +107,9 @@ def patch(module):
def output_float(x): def output_float(x):
print("%f" % x) print("%f" % x)
def output_strln(x):
print(x, end='')
def dbg_stack_address(_): def dbg_stack_address(_):
return 0 return 0
@ -85,6 +123,8 @@ def patch(module):
return output_asciiart return output_asciiart
elif name == "output_float64": elif name == "output_float64":
return output_float return output_float
elif name == "output_str":
return output_strln
elif name in { elif name in {
"output_bool", "output_bool",
"output_int32", "output_int32",
@ -92,7 +132,8 @@ def patch(module):
"output_int32_list", "output_int32_list",
"output_uint32", "output_uint32",
"output_uint64", "output_uint64",
"output_str", "output_strln",
"output_range",
}: }:
return print return print
elif name == "dbg_stack_address": elif name == "dbg_stack_address":
@ -104,9 +145,12 @@ def patch(module):
module.int64 = int64 module.int64 = int64
module.uint32 = uint32 module.uint32 = uint32
module.uint64 = uint64 module.uint64 = uint64
module.bool = _bool
module.float = _float
module.TypeVar = TypeVar module.TypeVar = TypeVar
module.ConstGeneric = ConstGeneric module.ConstGeneric = ConstGeneric
module.Generic = Generic module.Generic = Generic
module.Literal = Literal
module.extern = extern module.extern = extern
module.Option = Option module.Option = Option
module.Some = Some module.Some = Some
@ -116,16 +160,31 @@ def patch(module):
module.round = round_away_zero module.round = round_away_zero
module.round64 = round_away_zero module.round64 = round_away_zero
module.np_round = np.round module.np_round = np.round
module.floor = math.floor module.floor = _floor
module.floor64 = math.floor module.floor64 = _floor
module.np_floor = np.floor module.np_floor = np.floor
module.ceil = math.ceil module.ceil = _ceil
module.ceil64 = math.ceil module.ceil64 = _ceil
module.np_ceil = np.ceil module.np_ceil = np.ceil
# NumPy ndarray functions
module.ndarray = NDArray
module.np_ndarray = np.ndarray
module.np_empty = np.empty
module.np_zeros = np.zeros
module.np_ones = np.ones
module.np_full = np.full
module.np_eye = np.eye
module.np_identity = np.identity
module.np_array = np.array
# NumPy Math functions # NumPy Math functions
module.np_isnan = np.isnan module.np_isnan = np.isnan
module.np_isinf = np.isinf module.np_isinf = np.isinf
module.np_min = np.min
module.np_minimum = np.minimum
module.np_max = np.max
module.np_maximum = np.maximum
module.np_sin = np.sin module.np_sin = np.sin
module.np_cos = np.cos module.np_cos = np.cos
module.np_exp = np.exp module.np_exp = np.exp
@ -165,6 +224,14 @@ def patch(module):
module.sp_spec_j0 = special.j0 module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1 module.sp_spec_j1 = special.j1
# NumPy NDArray Functions
module.np_ndarray = np.ndarray
module.np_empty = np.empty
module.np_zeros = np.zeros
module.np_ones = np.ones
module.np_full = np.full
module.np_eye = np.eye
module.np_identity = np.identity
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)

View File

@ -10,6 +10,10 @@ fi
declare -a nac3args declare -a nac3args
while [ $# -ge 1 ]; do while [ $# -ge 1 ]; do
case "$1" in case "$1" in
--help)
echo "Usage: run_demo.sh [--help] [--out OUTFILE] [--lli] [--debug] -- [NAC3ARGS...]"
exit
;;
--out) --out)
shift shift
outfile="$1" outfile="$1"
@ -17,14 +21,28 @@ while [ $# -ge 1 ]; do
--lli) --lli)
use_lli=1 use_lli=1
;; ;;
--debug)
debug=1
;;
--)
shift
break
;;
*) *)
nac3args+=("$1") break
;; ;;
esac esac
shift shift
done done
if [ -e ../../target/release/nac3standalone ]; then while [ $# -ge 1 ]; do
nac3args+=("$1")
shift
done
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
nac3standalone=../../target/debug/nac3standalone
elif [ -e ../../target/release/nac3standalone ]; then
nac3standalone=../../target/release/nac3standalone nac3standalone=../../target/release/nac3standalone
else else
# used by Nix builds # used by Nix builds

View File

@ -6,6 +6,10 @@ def output_int32(x: int32):
def output_int64(x: int64): def output_int64(x: int64):
... ...
@extern
def output_strln(x: str):
...
class B: class B:
b: int32 b: int32
@ -23,10 +27,14 @@ class A:
def get_a(self) -> int32: def get_a(self) -> int32:
return self.a return self.a
def get_b(self) -> B: # def get_b(self) -> B:
return self.b # return self.b
class Initless:
def foo(self):
output_strln("hello")
def run() -> int32: def run() -> int32:
a = A(10) a = A(10)
output_int32(a.a) output_int32(a.a)
@ -35,4 +43,8 @@ def run() -> int32:
output_int32(a.a) output_int32(a.a)
output_int32(a.get_a()) output_int32(a.get_a())
# output_int32(a.get_b().b) FIXME: NAC3 prints garbage # output_int32(a.get_b().b) FIXME: NAC3 prints garbage
initless = Initless()
initless.foo()
return 0 return 0

View File

@ -16,28 +16,28 @@ class HybridGenericClass2(Generic[A, T]):
class HybridGenericClass3(Generic[T, A, B]): class HybridGenericClass3(Generic[T, A, B]):
pass pass
def make_generic_2() -> ConstGenericClass[2]: def make_generic_2() -> ConstGenericClass[Literal[2]]:
return ... return ...
def make_generic2_1_2() -> ConstGeneric2Class[1, 2]: def make_generic2_1_2() -> ConstGeneric2Class[Literal[1], Literal[2]]:
return ... return ...
def make_hybrid_class_2_int32() -> HybridGenericClass2[2, int32]: def make_hybrid_class_2_int32() -> HybridGenericClass2[Literal[2], int32]:
return ... return ...
def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, 0, 1]: def make_hybrid_class_i32_0_1() -> HybridGenericClass3[int32, Literal[0], Literal[1]]:
return ... return ...
def consume_generic_2(instance: ConstGenericClass[2]): def consume_generic_2(instance: ConstGenericClass[Literal[2]]):
pass pass
def consume_generic2_1_2(instance: ConstGeneric2Class[1, 2]): def consume_generic2_1_2(instance: ConstGeneric2Class[Literal[1], Literal[2]]):
pass pass
def consume_hybrid_class_2_i32(instance: HybridGenericClass2[2, int32]): def consume_hybrid_class_2_i32(instance: HybridGenericClass2[Literal[2], int32]):
pass pass
def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, 0, 1]): def consume_hybrid_class_i32_0_1(instance: HybridGenericClass3[int32, Literal[0], Literal[1]]):
pass pass
def f(): def f():

View File

@ -22,6 +22,10 @@ def output_uint64(x: uint64):
def output_float64(x: float): def output_float64(x: float):
... ...
@extern
def output_range(x: range):
...
@extern @extern
def output_int32_list(x: list[int32]): def output_int32_list(x: list[int32]):
... ...
@ -34,6 +38,10 @@ def output_asciiart(x: int32):
def output_str(x: str): def output_str(x: str):
... ...
@extern
def output_strln(x: str):
...
def test_output_bool(): def test_output_bool():
output_bool(True) output_bool(True)
output_bool(False) output_bool(False)
@ -59,6 +67,15 @@ def test_output_float64():
output_float64(16.25) output_float64(16.25)
output_float64(-16.25) output_float64(-16.25)
def test_output_range():
r = range(1, 100, 5)
output_int32(r.start)
output_int32(r.stop)
output_int32(r.step)
output_range(range(10))
output_range(range(1, 10))
output_range(range(1, 10, 2))
def test_output_asciiart(): def test_output_asciiart():
for i in range(17): for i in range(17):
output_asciiart(i) output_asciiart(i)
@ -68,7 +85,8 @@ def test_output_int32_list():
output_int32_list([0, 1, 3, 5, 10]) output_int32_list([0, 1, 3, 5, 10])
def test_output_str_family(): def test_output_str_family():
output_str("hello world") output_str("hello")
output_strln(" world")
def run() -> int32: def run() -> int32:
test_output_bool() test_output_bool()
@ -77,6 +95,7 @@ def run() -> int32:
test_output_uint32() test_output_uint32()
test_output_uint64() test_output_uint64()
test_output_float64() test_output_float64()
test_output_range()
test_output_asciiart() test_output_asciiart()
test_output_int32_list() test_output_int32_list()
test_output_str_family() test_output_str_family()

View File

@ -1,3 +1,7 @@
@extern
def output_bool(x: bool):
...
@extern @extern
def output_int32_list(x: list[int32]): def output_int32_list(x: list[int32]):
... ...
@ -30,6 +34,32 @@ def run() -> int32:
get_list_slice() get_list_slice()
list_slice_assignment() list_slice_assignment()
output_int32_list([1, 2, 3] + [4, 5, 6])
output_int32_list([1, 2, 3] * 3)
output_bool([] == [])
output_bool([0] == [])
output_bool([0] == [0])
output_bool([0, 1] == [0])
output_bool([0, 1] == [0, 1])
output_bool([] != [])
output_bool([0] != [])
output_bool([0] != [0])
output_bool([0] != [0, 1])
output_bool([0, 1] != [0, 1])
output_bool([] == [] == [])
output_bool([0] == [0] == [0])
output_bool([0, 1] == [0] == [0, 1])
output_bool([0, 1] == [0, 1] == [0])
output_bool([0] == [0, 1] == [0, 1])
output_bool([0, 1] == [0, 1] == [0, 1])
output_bool([] != [] != [])
output_bool([0] != [0] != [0])
output_bool([0, 1] != [0] != [0, 1])
output_bool([0, 1] != [0, 1] != [0])
output_bool([0] != [0, 1] != [0, 1])
output_bool([0, 1] != [0, 1] != [0, 1])
return 0 return 0
def get_list_slice(): def get_list_slice():

View File

@ -23,11 +23,12 @@ def run() -> int32:
output_int32(x) output_int32(x)
output_str(" * ") output_str(" * ")
output_float64(n / x) output_float64(n / x)
output_str("\n")
except: # Assume this is intended to catch x == 0 except: # Assume this is intended to catch x == 0
break break
else: else:
# loop fell through without finding a factor # loop fell through without finding a factor
output_int32(n) output_int32(n)
output_str(" is a prime number") output_str(" is a prime number\n")
return 0 return 0

View File

@ -37,7 +37,7 @@ def test_round64():
output_int64(round64(x)) output_int64(round64(x))
def test_np_round(): def test_np_round():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]: for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan(), 0.0, -0.0, 1.6, 1.4, -1.4, -1.6]:
output_float64(np_round(x)) output_float64(np_round(x))
def test_np_isnan(): def test_np_isnan():
@ -162,7 +162,7 @@ def test_np_expm1():
def test_np_cbrt(): def test_np_cbrt():
for x in [1.0, 8.0, 27.0, dbl_inf(), -dbl_inf(), dbl_nan()]: for x in [1.0, 8.0, 27.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(np_expm1(x)) output_float64(np_cbrt(x))
def test_sp_spec_erf(): def test_sp_spec_erf():
for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, dbl_inf(), -dbl_inf(), dbl_nan()]: for x in [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, dbl_inf(), -dbl_inf(), dbl_nan()]:

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,9 @@
from __future__ import annotations from __future__ import annotations
@extern
def output_bool(x: bool):
...
@extern @extern
def output_int32(x: int32): def output_int32(x: int32):
... ...
@ -17,14 +21,27 @@ def output_float64(x: float):
... ...
def run() -> int32: def run() -> int32:
test_bool()
test_int32() test_int32()
test_uint32() test_uint32()
test_int64() test_int64()
test_uint64() test_uint64()
test_A() # test_A()
test_B() # test_B()
return 0 return 0
def test_bool():
t = True
f = False
output_bool(not t)
output_bool(not f)
output_int32(~t)
output_int32(~f)
output_int32(+t)
output_int32(+f)
output_int32(-t)
output_int32(-f)
def test_int32(): def test_int32():
a = 17 a = 17
b = 3 b = 3
@ -173,91 +190,92 @@ def test_uint64():
a >>= uint32(b) a >>= uint32(b)
output_uint64(a) output_uint64(a)
class A: # FIXME Fix returning objects of non-primitive types; Currently this is disabled in the function checker
a: int32 # class A:
def __init__(self, a: int32): # a: int32
self.a = a # def __init__(self, a: int32):
# self.a = a
def __add__(self, other: A) -> A: #
output_int32(self.a + other.a) # def __add__(self, other: A) -> A:
return A(self.a + other.a) # output_int32(self.a + other.a)
# return A(self.a + other.a)
def __sub__(self, other: A) -> A: #
output_int32(self.a - other.a) # def __sub__(self, other: A) -> A:
return A(self.a - other.a) # output_int32(self.a - other.a)
# return A(self.a - other.a)
def test_A(): #
a = A(17) # def test_A():
b = A(3) # a = A(17)
# b = A(3)
c = a + b #
# fail due to alloca in __add__ function # c = a + b
# output_int32(c.a) # # fail due to alloca in __add__ function
# # output_int32(c.a)
a += b #
# fail due to alloca in __add__ function # a += b
# output_int32(a.a) # # fail due to alloca in __add__ function
# # output_int32(a.a)
a = A(17) #
b = A(3) # a = A(17)
d = a - b # b = A(3)
# fail due to alloca in __add__ function # d = a - b
# output_int32(c.a) # # fail due to alloca in __add__ function
# # output_int32(c.a)
a -= b #
# fail due to alloca in __add__ function # a -= b
# output_int32(a.a) # # fail due to alloca in __add__ function
# # output_int32(a.a)
a = A(17) #
b = A(3) # a = A(17)
a.__add__(b) # b = A(3)
a.__sub__(b) # a.__add__(b)
# a.__sub__(b)
#
class B: #
a: int32 # class B:
def __init__(self, a: int32): # a: int32
self.a = a # def __init__(self, a: int32):
# self.a = a
def __add__(self, other: B) -> B: #
output_int32(self.a + other.a) # def __add__(self, other: B) -> B:
return B(self.a + other.a) # output_int32(self.a + other.a)
# return B(self.a + other.a)
def __sub__(self, other: B) -> B: #
output_int32(self.a - other.a) # def __sub__(self, other: B) -> B:
return B(self.a - other.a) # output_int32(self.a - other.a)
# return B(self.a - other.a)
def __iadd__(self, other: B) -> B: #
output_int32(self.a + other.a + 24) # def __iadd__(self, other: B) -> B:
return B(self.a + other.a + 24) # output_int32(self.a + other.a + 24)
# return B(self.a + other.a + 24)
def __isub__(self, other: B) -> B: #
output_int32(self.a - other.a - 24) # def __isub__(self, other: B) -> B:
return B(self.a - other.a - 24) # output_int32(self.a - other.a - 24)
# return B(self.a - other.a - 24)
def test_B(): #
a = B(17) # def test_B():
b = B(3) # a = B(17)
# b = B(3)
c = a + b #
# fail due to alloca in __add__ function # c = a + b
# output_int32(c.a) # # fail due to alloca in __add__ function
# # output_int32(c.a)
a += b #
# fail due to alloca in __add__ function # a += b
# output_int32(a.a) # # fail due to alloca in __add__ function
# # output_int32(a.a)
a = B(17) #
b = B(3) # a = B(17)
d = a - b # b = B(3)
# fail due to alloca in __add__ function # d = a - b
# output_int32(c.a) # # fail due to alloca in __add__ function
# # output_int32(c.a)
a -= b #
# fail due to alloca in __add__ function # a -= b
# output_int32(a.a) # # fail due to alloca in __add__ function
# # output_int32(a.a)
a = B(17) #
b = B(3) # a = B(17)
a.__add__(b) # b = B(3)
a.__sub__(b) # a.__add__(b)
# a.__sub__(b)

View File

@ -9,8 +9,8 @@ use nac3core::{
}; };
use nac3parser::ast::{self, StrRef}; use nac3parser::ast::{self, StrRef};
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use std::{collections::HashMap, sync::Arc};
use std::collections::HashSet; use std::collections::HashSet;
use std::{collections::HashMap, sync::Arc};
pub struct ResolverInternal { pub struct ResolverInternal {
pub id_to_type: Mutex<HashMap<StrRef, Type>>, pub id_to_type: Mutex<HashMap<StrRef, Type>>,
@ -63,8 +63,12 @@ impl SymbolResolver for Resolver {
} }
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> { fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.0.id_to_def.lock().get(&id).copied() self.0
.ok_or_else(|| HashSet::from(["Undefined identifier".to_string()])) .id_to_def
.lock()
.get(&id)
.copied()
.ok_or_else(|| HashSet::from([format!("Undefined identifier `{id}`")]))
} }
fn get_string_id(&self, s: &str) -> i32 { fn get_string_id(&self, s: &str) -> i32 {
@ -72,7 +76,8 @@ impl SymbolResolver for Resolver {
if let Some(id) = str_store.get(s) { if let Some(id) = str_store.get(s) {
*id *id
} else { } else {
let id = str_store.len() as i32; let id = i32::try_from(str_store.len())
.expect("Symbol resolver string store size exceeds max capacity (i32::MAX)");
str_store.insert(s.to_string(), id); str_store.insert(s.to_string(), id);
id id
} }

View File

@ -1,14 +1,22 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(clippy::too_many_lines, clippy::wildcard_imports)]
use clap::Parser; use clap::Parser;
use inkwell::{ use inkwell::{
memory_buffer::MemoryBuffer, memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*,
passes::PassBuilderOptions,
support::is_multithreaded,
targets::*,
OptimizationLevel, OptimizationLevel,
}; };
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use std::{collections::HashMap, fs, path::Path, sync::Arc};
use std::collections::HashSet; use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::{collections::HashMap, fs, path::Path, sync::Arc};
use nac3core::{ use nac3core::{
codegen::{ codegen::{
@ -17,12 +25,14 @@ use nac3core::{
}, },
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{ toplevel::{
composer::TopLevelComposer, helper::parse_parameter_default_value, type_annotation::*, composer::{ComposerConfig, TopLevelComposer},
helper::parse_parameter_default_value,
type_annotation::*,
TopLevelDef, TopLevelDef,
}, },
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{FunSignature, Type, Unifier}, typedef::{FunSignature, Type, Unifier, VarMap},
}, },
}; };
use nac3parser::{ use nac3parser::{
@ -32,10 +42,6 @@ use nac3parser::{
mod basic_symbol_resolver; mod basic_symbol_resolver;
use basic_symbol_resolver::*; use basic_symbol_resolver::*;
use nac3core::toplevel::composer::ComposerConfig;
#[cfg(test)]
mod test;
/// Command-line argument parser definition. /// Command-line argument parser definition.
#[derive(Parser)] #[derive(Parser)]
@ -80,19 +86,18 @@ fn handle_typevar_definition(
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
) -> Result<Type, HashSet<String>> { ) -> Result<Type, HashSet<String>> {
let ExprKind::Call { func, args, .. } = &var.node else { let ExprKind::Call { func, args, .. } = &var.node else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"expression {var:?} cannot be handled as a generic parameter in global scope" "expression {var:?} cannot be handled as a generic parameter in global scope"
), )]));
]))
}; };
match &func.node { match &func.node {
ExprKind::Name { id, .. } if id == &"TypeVar".into() => { ExprKind::Name { id, .. } if id == &"TypeVar".into() => {
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else { let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!("Expected string constant for first parameter of `TypeVar`, got {:?}", &args[0].node), "Expected string constant for first parameter of `TypeVar`, got {:?}",
])) &args[0].node
)]));
}; };
let generic_name: StrRef = ty_name.to_string().into(); let generic_name: StrRef = ty_name.to_string().into();
@ -106,39 +111,35 @@ fn handle_typevar_definition(
unifier, unifier,
primitives, primitives,
x, x,
HashMap::default(), HashMap::new(),
None,
)?; )?;
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)
def_list, unifier, &ty, &mut None
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let loc = func.location; let loc = func.location;
if constraints.len() == 1 { if constraints.len() == 1 {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!("A single constraint is not allowed (at {loc})"), "A single constraint is not allowed (at {loc})"
])) )]));
} }
Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).0) Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).ty)
} }
ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => { ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => {
if args.len() != 2 { if args.len() != 2 {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!("Expected 2 arguments for `ConstGeneric`, got {}", args.len()), "Expected 2 arguments for `ConstGeneric`, got {}",
])) args.len()
)]));
} }
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else { let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!(
"Expected string constant for first parameter of `ConstGeneric`, got {:?}", "Expected string constant for first parameter of `ConstGeneric`, got {:?}",
&args[0].node &args[0].node
), )]));
]))
}; };
let generic_name: StrRef = ty_name.to_string().into(); let generic_name: StrRef = ty_name.to_string().into();
@ -148,22 +149,18 @@ fn handle_typevar_definition(
unifier, unifier,
primitives, primitives,
&args[1], &args[1],
HashMap::default(), HashMap::new(),
None,
)?;
let constraint = get_type_from_type_annotation_kinds(
def_list, unifier, &ty, &mut None
)?; )?;
let constraint =
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?;
let loc = func.location; let loc = func.location;
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).0) Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).ty)
} }
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!(
format!(
"expression {var:?} cannot be handled as a generic parameter in global scope" "expression {var:?} cannot be handled as a generic parameter in global scope"
), )])),
]))
} }
} }
@ -179,18 +176,12 @@ fn handle_assignment_pattern(
if targets.len() == 1 { if targets.len() == 1 {
match &targets[0].node { match &targets[0].node {
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
if let Ok(var) = handle_typevar_definition( if let Ok(var) =
value, handle_typevar_definition(value, resolver, def_list, unifier, primitives)
resolver, {
def_list,
unifier,
primitives,
) {
internal_resolver.add_id_type(*id, var); internal_resolver.add_id_type(*id, var);
Ok(()) Ok(())
} else if let Ok(val) = } else if let Ok(val) = parse_parameter_default_value(value, resolver) {
parse_parameter_default_value(value, resolver)
{
internal_resolver.add_module_global(*id, val); internal_resolver.add_module_global(*id, val);
Ok(()) Ok(())
} else { } else {
@ -242,25 +233,17 @@ fn handle_assignment_pattern(
)) ))
} }
} }
_ => Err(format!( _ => Err(format!("unpack of this expression is not supported at {}", value.location)),
"unpack of this expression is not supported at {}",
value.location
)),
} }
} }
} }
fn main() { fn main() {
const SIZE_T: u32 = usize::BITS;
let cli = CommandLineArgs::parse(); let cli = CommandLineArgs::parse();
let CommandLineArgs { let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } =
file_name, cli;
threads,
opt_level,
emit_llvm,
triple,
mcpu,
target_features,
} = cli;
Target::initialize_all(&InitializationConfig::default()); Target::initialize_all(&InitializationConfig::default());
@ -272,11 +255,9 @@ fn main() {
let target_features = target_features.unwrap_or_default(); let target_features = target_features.unwrap_or_default();
let threads = if is_multithreaded() { let threads = if is_multithreaded() {
if threads == 0 { if threads == 0 {
std::thread::available_parallelism() std::thread::available_parallelism().map(NonZeroUsize::get).unwrap_or(1usize)
.map(|threads| threads.get() as u32)
.unwrap_or(1u32)
} else { } else {
threads threads as usize
} }
} else { } else {
if threads != 1 { if threads != 1 {
@ -300,9 +281,9 @@ fn main() {
} }
}; };
let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0; let primitive: PrimitiveStore = TopLevelComposer::make_primitives(SIZE_T).0;
let (mut composer, builtins_def, builtins_ty) = let (mut composer, builtins_def, builtins_ty) =
TopLevelComposer::new(vec![], ComposerConfig::default()); TopLevelComposer::new(vec![], ComposerConfig::default(), SIZE_T);
let internal_resolver: Arc<ResolverInternal> = ResolverInternal { let internal_resolver: Arc<ResolverInternal> = ResolverInternal {
id_to_type: builtins_ty.into(), id_to_type: builtins_ty.into(),
@ -310,7 +291,8 @@ fn main() {
class_names: Mutex::default(), class_names: Mutex::default(),
module_globals: Mutex::default(), module_globals: Mutex::default(),
str_store: Mutex::default(), str_store: Mutex::default(),
}.into(); }
.into();
let resolver = let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>; Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
@ -334,13 +316,16 @@ fn main() {
eprintln!("{err}"); eprintln!("{err}");
return; return;
} }
}, }
// allow (and ignore) "from __future__ import annotations" // allow (and ignore) "from __future__ import annotations"
StmtKind::ImportFrom { module, names, .. } StmtKind::ImportFrom { module, names, .. }
if module == &Some("__future__".into()) && names.len() == 1 && names[0].name == "annotations".into() => (), if module == &Some("__future__".into())
&& names.len() == 1
&& names[0].name == "annotations".into() => {}
_ => { _ => {
let (name, def_id, ty) = let (name, def_id, ty) = composer
composer.register_top_level(stmt, Some(resolver.clone()), "__main__", true).unwrap(); .register_top_level(stmt, Some(resolver.clone()), "__main__", true)
.unwrap();
internal_resolver.add_id_def(name, def_id); internal_resolver.add_id_def(name, def_id);
if let Some(ty) = ty { if let Some(ty) = ty {
internal_resolver.add_id_type(name, ty); internal_resolver.add_id_type(name, ty);
@ -349,13 +334,25 @@ fn main() {
} }
} }
let signature = FunSignature { args: vec![], ret: primitive.int32, vars: HashMap::new() }; let signature = FunSignature { args: vec![], ret: primitive.int32, vars: VarMap::new() };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new(); let mut cache = HashMap::new();
let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache); let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache);
let signature = store.add_cty(signature); let signature = store.add_cty(signature);
composer.start_analysis(true).unwrap(); if let Err(errors) = composer.start_analysis(true) {
let error_count = errors.len();
eprintln!("{error_count} error(s) occurred during top level analysis.");
for (error_i, error) in errors.iter().enumerate() {
let error_num = error_i + 1;
eprintln!("=========== ERROR {error_num}/{error_count} ============");
eprintln!("{error}");
}
eprintln!("==================================");
panic!("top level analysis failed");
}
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
@ -366,7 +363,8 @@ fn main() {
.unwrap_or_else(|_| panic!("cannot find run() entry point")) .unwrap_or_else(|_| panic!("cannot find run() entry point"))
.0] .0]
.write(); .write();
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance else { let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance
else {
unreachable!() unreachable!()
}; };
instance_to_symbol.insert(String::new(), "run".to_string()); instance_to_symbol.insert(String::new(), "run".to_string());
@ -405,7 +403,7 @@ fn main() {
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
}))); })));
let threads = (0..threads) let threads = (0..threads)
.map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), 64))) .map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), SIZE_T)))
.collect(); .collect();
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f); let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task); registry.add_task(task);
@ -446,7 +444,8 @@ fn main() {
function_iter = func.get_next_function(); function_iter = func.get_next_function();
} }
let target_machine = llvm_options.target let target_machine = llvm_options
.target
.create_target_machine(llvm_options.opt_level) .create_target_machine(llvm_options.opt_level)
.expect("couldn't create target machine"); .expect("couldn't create target machine");

View File

@ -1,148 +0,0 @@
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
use std::process::{Command, Output};
use parking_lot::Mutex;
/// [Mutex] for enforcing only a single `./run_demo.sh` is executing.
///
/// This is necessary as the command will generate temporary files (specifically `*.o`, `*.bc`, and
/// `demo`) in the filesystem.
static RUN_EXEC_MTX: Mutex<()> = Mutex::new(());
/// Dummy struct for automatically removing temporary files after test case execution completes.
struct CleanupDemoTempFiles;
impl Default for CleanupDemoTempFiles {
fn default() -> Self {
CleanupDemoTempFiles {}
}
}
impl Drop for CleanupDemoTempFiles {
fn drop(&mut self) {
get_demo_dir().read_dir()
.unwrap()
.into_iter()
.map(|dir_entry| dir_entry.unwrap().path().into_boxed_path())
.filter(|p| {
let ext = p.extension()
.and_then(|ext| ext.to_str())
.unwrap_or_default();
let file_name = p.file_name()
.and_then(|ext| ext.to_str())
.unwrap_or_default();
["bc", "o"].contains(&ext) || file_name == "demo"
})
.for_each(|p| fs::remove_file(p).unwrap())
}
}
/// Returns the absolute [Path] to the `demo` directory.
fn get_demo_dir() -> PathBuf {
std::env::current_dir()
.map(|p| p.join("demo"))
.and_then(|p| p.canonicalize())
.unwrap()
}
/// Collects all Python tests from `demo/src`.
fn collect_demo_tests() -> Vec<PathBuf> {
get_demo_dir()
.join("src")
.read_dir()
.unwrap()
.into_iter()
.filter_map(|dir_entry| {
let p = dir_entry.unwrap().path();
if p.is_file() && p.extension().and_then(|ext| ext.to_str()).unwrap_or_default() == "py" {
Some(p)
} else {
None
}
})
.collect()
}
/// Runs the interpreter on the `file` as if by invoking `interpret_demo.py`.
fn run_interpreter(file: &Path) -> io::Result<Output> {
Command::new("./interpret_demo.py")
.arg(file.to_str().unwrap())
.current_dir(get_demo_dir())
.output()
}
/// Runs NAC3 on the `file` as if by invoking `run_demo.sh`.
fn run_demo(file: &Path, nac3_args: &[&str]) -> io::Result<Output> {
let _lk = RUN_EXEC_MTX.lock();
Command::new("./run_demo.sh")
.args(nac3_args)
.arg(file.to_str().unwrap())
.current_dir(get_demo_dir())
.output()
}
#[test]
fn test_demos_optimized() {
let _cleanup = CleanupDemoTempFiles::default();
let tests = collect_demo_tests();
for test in tests.iter() {
let py_interpreted= run_interpreter(test).unwrap();
let py_stdout = String::from_utf8(py_interpreted.stdout).unwrap();
let executed = run_demo(test, &[]).unwrap();
let exec_stdout = String::from_utf8(executed.stdout).unwrap();
assert_eq!(py_stdout, exec_stdout, "Execution result mismatch for {}", test.file_name().and_then(|p| p.to_str()).unwrap());
let bc_interpreted = run_demo(test, &["--lli"]).unwrap();
let bc_stdout = String::from_utf8(bc_interpreted.stdout).unwrap();
assert_eq!(py_stdout, bc_stdout, "Execution result mismatch for {}", test.file_name().and_then(|p| p.to_str()).unwrap());
}
}
#[test]
fn test_demos_unoptimized() {
let _cleanup = CleanupDemoTempFiles::default();
let tests = collect_demo_tests();
for test in tests.iter() {
let py_interpreted= run_interpreter(test).unwrap();
let py_stdout = String::from_utf8(py_interpreted.stdout).unwrap();
let executed = run_demo(test, &["-O0"]).unwrap();
let exec_stdout = String::from_utf8(executed.stdout).unwrap();
assert_eq!(py_stdout, exec_stdout, "Execution result differs for {}", test.file_name().and_then(|p| p.to_str()).unwrap());
let bc_interpreted= run_demo(test, &["--lli", "-O0"]).unwrap();
let bc_stdout = String::from_utf8(bc_interpreted.stdout).unwrap();
assert_eq!(py_stdout, bc_stdout, "Execution result mismatch for {}", test.file_name().and_then(|p| p.to_str()).unwrap());
}
}
#[test]
fn test_demos_multithreaded() {
let _cleanup = CleanupDemoTempFiles::default();
let tests = collect_demo_tests();
for test in tests.iter() {
let py_interpreted= run_interpreter(test).unwrap();
let py_stdout = String::from_utf8(py_interpreted.stdout).unwrap();
let executed = run_demo(test, &["-T4"]).unwrap();
let exec_stdout = String::from_utf8(executed.stdout).unwrap();
assert_eq!(py_stdout, exec_stdout, "Execution result differs for {}", test.file_name().and_then(|p| p.to_str()).unwrap());
let bc_interpreted= run_demo(test, &["--lli", "-T4"]).unwrap();
let bc_stdout = String::from_utf8(bc_interpreted.stdout).unwrap();
assert_eq!(py_stdout, bc_stdout, "Execution result mismatch for {}", test.file_name().and_then(|p| p.to_str()).unwrap());
}
}

View File

@ -81,6 +81,7 @@ in rec {
'' ''
mkdir -p $out/bin mkdir -p $out/bin
ln -s ${llvm-nac3}/bin/clang.exe $out/bin/clang-irrt.exe ln -s ${llvm-nac3}/bin/clang.exe $out/bin/clang-irrt.exe
ln -s ${llvm-nac3}/bin/clang.exe $out/bin/clang-irrt-test.exe
ln -s ${llvm-nac3}/bin/llvm-as.exe $out/bin/llvm-as-irrt.exe ln -s ${llvm-nac3}/bin/llvm-as.exe $out/bin/llvm-as-irrt.exe
''; '';
nac3artiq = pkgs.rustPlatform.buildRustPackage { nac3artiq = pkgs.rustPlatform.buildRustPackage {

View File

@ -1,123 +1,123 @@
{ pkgs } : [ { pkgs } : [
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libffi-3.4.4-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunwind-18.1.8-1-any.pkg.tar.zst";
sha256 = "0mws1g7w11riczc168x7kzb8nl74iry4bzkb72nspw1vmlblxfy6"; sha256 = "1v8zkfcbf1ga2ndpd1j0dwv5s1rassxs2b5pjhcsmqwjcvczba1m";
name = "mingw-w64-clang-x86_64-libffi-3.4.4-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libunwind-18.1.8-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunwind-17.0.4-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc++-18.1.8-1-any.pkg.tar.zst";
sha256 = "1748ncih1zj4vnj9c0gcp5rf21gm2pn9xdy2hrigp2pvfchrnmwn"; sha256 = "0mfd8wrmgx12j5gf354j7pk1l3lg9ykxvq75xdk3jipsr6hbn846";
name = "mingw-w64-clang-x86_64-libunwind-17.0.4-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libc++-18.1.8-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libc++-17.0.4-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libffi-3.4.6-1-any.pkg.tar.zst";
sha256 = "00waw5qmpycib7xhk6aywqgqg8y26q2xbm8m4za7mpy04g2alav4"; sha256 = "1q6gms980985bp087rnnpvz2fwfakgm5266izfk3b1mbp620s1yv";
name = "mingw-w64-clang-x86_64-libc++-17.0.4-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libffi-3.4.6-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-zlib-1.3-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libiconv-1.17-4-any.pkg.tar.zst";
sha256 = "0rfwz7czvwa8clickjw1kd114miyifxf1s2v9mxffkaivm45f62v"; sha256 = "1g2bkhgf60dywccxw911ydyigf3m25yqfh81m5099swr7mjsmzyf";
name = "mingw-w64-clang-x86_64-zlib-1.3-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libiconv-1.17-4-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libiconv-1.17-3-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-gettext-runtime-0.22.5-2-any.pkg.tar.zst";
sha256 = "1hxmdgivb86h7wz9hcp0had99ngv157w1fbjg7cgy068zv787m8w"; sha256 = "0ll6ci6d3mc7g04q0xixjc209bh8r874dqbczgns69jsad3wg6mi";
name = "mingw-w64-clang-x86_64-libiconv-1.17-3-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-gettext-runtime-0.22.5-2-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-expat-2.5.0-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-xz-5.6.2-2-any.pkg.tar.zst";
sha256 = "1gi9ckh48k64gras307f6pf5y558hj80izlxri8cnwvzzmmra8dg"; sha256 = "0phb9hwqksk1rg29yhwlc7si78zav19c2kac0i841pc7mc2n9gzx";
name = "mingw-w64-clang-x86_64-expat-2.5.0-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-xz-5.6.2-2-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-gettext-0.22.4-3-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-zlib-1.3.1-1-any.pkg.tar.zst";
sha256 = "0i2gbhzvcy8cq7gnd9zsjw47jmn8gsq5pam5j3yn84jn578f3mfi"; sha256 = "06i9xjsskf4ddb2ph4h31md5c7imj9mzjhd4lc4q44j8dmpc1w5p";
name = "mingw-w64-clang-x86_64-gettext-0.22.4-3-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-zlib-1.3.1-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-xz-5.4.5-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libxml2-2.12.8-1-any.pkg.tar.zst";
sha256 = "1niky9s7qq0434ljma3f9m8yybcifkvkxwhk580crzb2ifpmqxc3"; sha256 = "1imipb0dz4w6x4n9arn22imyzzcwdlf2cqxvn7irqq7w9by6fy0b";
name = "mingw-w64-clang-x86_64-xz-5.4.5-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libxml2-2.12.8-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libxml2-2.12.1-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-zstd-1.5.6-2-any.pkg.tar.zst";
sha256 = "0lajdhkqv3hrfmnf9wa0ddqpr9z8rvb037rv0ss60g2n0b92s5gk"; sha256 = "02cp5ci8w50k7xn38mpkwnr8sn898v18wcc07y8f9sfla7vcyfix";
name = "mingw-w64-clang-x86_64-libxml2-2.12.1-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-zstd-1.5.6-2-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-zstd-1.5.5-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-libs-18.1.8-1-any.pkg.tar.zst";
sha256 = "07739wmwgxf0d6db4p8w302a6jwcm01aafr1s8jvcl5k1h5a1m2m"; sha256 = "0rpbgvvinsqflhd3nhfxk0g0yy8j80zzw5yx6573ak0m78a9fa06";
name = "mingw-w64-clang-x86_64-zstd-1.5.5-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-llvm-libs-18.1.8-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-libs-17.0.4-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-18.1.8-1-any.pkg.tar.zst";
sha256 = "13gpp9ah8g4ggzmgsskwcrck16rvjwp464i5vyw9zwm88bdhmz93"; sha256 = "185g5h8q3x3rav9lp2njln58ny2idh2067fd02j3nsbik6glshpf";
name = "mingw-w64-clang-x86_64-llvm-libs-17.0.4-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-llvm-18.1.8-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-17.0.4-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-libs-18.1.8-1-any.pkg.tar.zst";
sha256 = "103zfg4ypmzhy0gd02arjydca2l4ak2mai3g1xrck83l15v85f9h"; sha256 = "089hji3yd7wsd03v9mdfgc99l5k1dql8kg7p3hy13vrbgfsabxhc";
name = "mingw-w64-clang-x86_64-llvm-17.0.4-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-clang-libs-18.1.8-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-compiler-rt-17.0.4-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-compiler-rt-18.1.8-1-any.pkg.tar.zst";
sha256 = "0n38jb9gl3zmy7f07bqzpk70k00hdrnasrclhcxgbn4ylzlca07w"; sha256 = "1dwcxnv1k5ljim5ys4h1c3jlrdpi0054z094ynav7if65i8zjj4a";
name = "mingw-w64-clang-x86_64-compiler-rt-17.0.4-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-compiler-rt-18.1.8-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-11.0.0.r404.g3a137bd87-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-headers-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
sha256 = "1wisa5j86xn8fanx4ks5pkj2hx0k89cippcap6c6yf6d3qmj4mvr"; sha256 = "1h3cdcajz29iq7vja908kkijz1vb9xn0f7w1lw1ima0q0zhinv4q";
name = "mingw-w64-clang-x86_64-headers-git-11.0.0.r404.g3a137bd87-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-headers-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-crt-git-11.0.0.r404.g3a137bd87-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-crt-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
sha256 = "0x1x3rnxivm58sqjx5v5smnzl5hdinx4qrjgkn59nylcd24fvnk4"; sha256 = "15kamyi3b0j6f5zxin4i2jgzjc7lzvwl4z5cz3dx0i8hg91aq0n7";
name = "mingw-w64-clang-x86_64-crt-git-11.0.0.r404.g3a137bd87-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-crt-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-lld-17.0.4-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-lld-18.1.8-1-any.pkg.tar.zst";
sha256 = "1028iymsay6smk856xfyccl7pkcxzigqlrhj4sgz4xxqij7xqnp0"; sha256 = "1vpij5d06m4kjy3qv8bizwlkl21gcv6fv0r2f1j9bclgm6k3144x";
name = "mingw-w64-clang-x86_64-lld-17.0.4-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-lld-18.1.8-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libwinpthread-git-11.0.0.r404.g3a137bd87-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
sha256 = "01q7r8mvn60yc0bxlhfbb1rx8w6lbvf1bd0xfaycqmc25b8vnfnp"; sha256 = "0qdvgs1rmjjhn9klf9kpw7l0ydz36rr5fasn4q9gpby2lgl11bkb";
name = "mingw-w64-clang-x86_64-libwinpthread-git-11.0.0.r404.g3a137bd87-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libwinpthread-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-11.0.0.r404.g3a137bd87-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
sha256 = "0f5828injyxhy8vxkv02wmk4fh6x5gsqnkr97p50vz692fxkwk48"; sha256 = "0rh2mn078cifcmr4as4k57jxjln5lbnsmpx47h9d0s5d2i8sf2rc";
name = "mingw-w64-clang-x86_64-winpthreads-git-11.0.0.r404.g3a137bd87-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-winpthreads-git-12.0.0.r81.g90abf784a-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-17.0.4-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-clang-18.1.8-1-any.pkg.tar.zst";
sha256 = "1adfijlky02waz6gm2rx32rd413ws5rphkpybihpc8hi11yag761"; sha256 = "1qny934nv4g75k9gb5sf31v24bgafkg6qw7r35xv3in491w6annq";
name = "mingw-w64-clang-x86_64-clang-17.0.4-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-clang-18.1.8-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-c-ares-1.22.1-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-c-ares-1.29.0-1-any.pkg.tar.zst";
sha256 = "0bv8n3862krxhlmz0lxqq71y40dy4flpjvb7adl5a96jixgsbxav"; sha256 = "01xg1h1a8kda0kq2921w25ybvm1ms7lfdzday0hv93f3myq7briq";
name = "mingw-w64-clang-x86_64-c-ares-1.22.1-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-c-ares-1.29.0-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -127,21 +127,21 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunistring-1.1-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libunistring-1.2-1-any.pkg.tar.zst";
sha256 = "16myvbg33q5s7jl30w5qd8n8f1r05335ms8r61234vn52n32l2c4"; sha256 = "13nz49li39z1zgfx1q9jg4vrmyrmqb6qdq0nqshidaqc6zr16k3g";
name = "mingw-w64-clang-x86_64-libunistring-1.1-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libunistring-1.2-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libidn2-2.3.4-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libidn2-2.3.7-2-any.pkg.tar.zst";
sha256 = "105valrldri39sx7d5zdscxgsz9px382f8vbbl2zpr2xzb3jq8p8"; sha256 = "07k8zh5nb2s82md7lz22r8gim8214rhlg586lywck3zcla98jv1w";
name = "mingw-w64-clang-x86_64-libidn2-2.3.4-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libidn2-2.3.7-2-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libpsl-0.21.2-4-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libpsl-0.21.5-2-any.pkg.tar.zst";
sha256 = "0h4inq6prhiipl3h5k9br9rz5mih123b8n12wiwp2qck0h2q3x98"; sha256 = "1mpx77q5g8pj45s8wgc52c4ww2r93080p6d559p56f558a3cl317";
name = "mingw-w64-clang-x86_64-libpsl-0.21.2-4-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libpsl-0.21.5-2-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -151,21 +151,21 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-p11-kit-0.25.3-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-p11-kit-0.25.5-1-any.pkg.tar.zst";
sha256 = "19330k2jxpzm552m7j116hz8qrn2d14h9ybi0lcmkj4c1l25m6mf"; sha256 = "00yz6cmr1ldlrskv811n345xcia88mj7w4fyx4m9z5848jxgsabd";
name = "mingw-w64-clang-x86_64-p11-kit-0.25.3-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-p11-kit-0.25.5-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ca-certificates-20230311-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ca-certificates-20240203-1-any.pkg.tar.zst";
sha256 = "00hdl239695xi5bgld7a1ssp6kapkb9az02dpx80vmz7mqg6wwxx"; sha256 = "1q5nxhsk04gidz66ai5wgd4dr04lfyakkfja9p0r5hrgg4ppqqjg";
name = "mingw-w64-clang-x86_64-ca-certificates-20230311-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-ca-certificates-20240203-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openssl-3.2.0-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openssl-3.3.1-1-any.pkg.tar.zst";
sha256 = "1623bkf4bjwf4cs1j5f6c37gwj4z775kjk4jswwy58r61yimwnd5"; sha256 = "0ywhwm4kw3qjzv0872qwabnsq2rzbmqjb9m69q3fykjl0m9gigsa";
name = "mingw-w64-clang-x86_64-openssl-3.2.0-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-openssl-3.3.1-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -175,33 +175,45 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp2-1.58.0-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp2-1.61.0-2-any.pkg.tar.zst";
sha256 = "00f7raqmky43v9h2356yzfbvbhkbsjpc17n6ksgpv66h31svna6q"; sha256 = "07bkk98126gy4k6lb9rrqqnzjfz9j2rsr5dzr2djmzdkw0h4dr95";
name = "mingw-w64-clang-x86_64-nghttp2-1.58.0-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-nghttp2-1.61.0-2-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-curl-8.4.0-2-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-nghttp3-1.4.0-1-any.pkg.tar.zst";
sha256 = "1qi446lq1qqq45xn7fmasfb914jhahqc0cvpnqi8vrp9vlk9dsw3"; sha256 = "007w2252nzn274j4wjc1vf56xyzzh5vg3blj1hil7mlmffgvc923";
name = "mingw-w64-clang-x86_64-curl-8.4.0-2-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-nghttp3-1.4.0-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rust-1.74.0-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-curl-8.8.0-10-any.pkg.tar.zst";
sha256 = "071xpmfsnr7j7ycf6w8kyg41fxs3dclnjrcp7grinck5bb5zxn6s"; sha256 = "024z5b1achkf448gxqy1i3gcw371x54kfl6igv08b5wb3rrw35a4";
name = "mingw-w64-clang-x86_64-rust-1.74.0-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-curl-8.8.0-10-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-pkgconf-1~2.1.0-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rust-1.79.0-1-any.pkg.tar.zst";
sha256 = "04y364mzx2sj984cdxq187fg73hzkrzvbpsr5d86xfzrag789nh3"; sha256 = "0i7s88hj8m4920xifkj7i4b4sq8cqq7p5cypp3jqx3dc44pwm19a";
name = "mingw-w64-clang-x86_64-pkgconf-12.1.0-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-rust-1.79.0-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-jsoncpp-1.9.5-2-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-pkgconf-1~2.2.0-1-any.pkg.tar.zst";
sha256 = "0cpy76crngj5dfg9f4l216ry0wcavp0nabyc0b9g676rg6400qas"; sha256 = "1y44ijg3y8p80f1yn9972nshrnyrd06a9sh984ajhxg8bi8s5xyl";
name = "mingw-w64-clang-x86_64-jsoncpp-1.9.5-2-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-pkgconf-12.2.0-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-expat-2.6.2-1-any.pkg.tar.zst";
sha256 = "0kj1vzjh3qh7d2g47avlgk7a6j4nc62111hy1m63jwq0alc01k38";
name = "mingw-w64-clang-x86_64-expat-2.6.2-1-any.pkg.tar.zst";
})
(pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-jsoncpp-1.9.5-3-any.pkg.tar.zst";
sha256 = "1a8mdn4ram9pgqpx5fwxmhcmzc6bh1fq1s4m37xh0d8p6fpncv10";
name = "mingw-w64-clang-x86_64-jsoncpp-1.9.5-3-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -223,57 +235,57 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libtre-git-r128.6fb7206-2-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libtre-git-r177.07e66d0-2-any.pkg.tar.zst";
sha256 = "0bmdla75k0q88l93ql9ajbfag4vhdhyp0glzymljvcc7ir0r6f9r"; sha256 = "0fc9hxsdks1xy5fv0rcna433hlzf6jhs77hg0hfzkzhn06f9alp4";
name = "mingw-w64-clang-x86_64-libtre-git-r128.6fb7206-2-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libtre-git-r177.07e66d0-2-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libsystre-1.0.1-4-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libsystre-1.0.1-5-any.pkg.tar.zst";
sha256 = "0x5ns8ld08gd9r5a98sqi3sm0vz588caaqw10ciix16sgyisfpdh"; sha256 = "05qsn8fkks4f93jkas43s47axqqgx5m64b45p462si3nlb8cjirq";
name = "mingw-w64-clang-x86_64-libsystre-1.0.1-4-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libsystre-1.0.1-5-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libarchive-3.7.2-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libarchive-3.7.4-1-any.pkg.tar.zst";
sha256 = "1p84yh6yzkdpmr02vyvgz16x5gycckah25jkdc2py09l7iw96bmw"; sha256 = "1ykw6imllgxv6lsgwxx1miqjr4l1iryqkrj286jcbfrb8ghpzhv5";
name = "mingw-w64-clang-x86_64-libarchive-3.7.2-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libarchive-3.7.4-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libuv-1.47.0-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-libuv-1.48.0-1-any.pkg.tar.zst";
sha256 = "1ch8g0mp13dmwi7sc6qxjcx18s1w0bsdl5lhkjmx5ccz55sspnyf"; sha256 = "0kfzanvx7hg7bvy35h2z2vcfxvwn44sikd36mvzhkv6c3c6y84sn";
name = "mingw-w64-clang-x86_64-libuv-1.47.0-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-libuv-1.48.0-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ninja-1.11.1-3-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ninja-1.12.1-1-any.pkg.tar.zst";
sha256 = "13wjfmyfr952n3ydpldjlwx1nla5xpyvr96ng8pfbyw4z900v5ms"; sha256 = "1vj9qaa43v316daz8k4ricmz3f33nhjpj7r0vn979nwmy7hzs7jx";
name = "mingw-w64-clang-x86_64-ninja-1.11.1-3-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-ninja-1.12.1-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rhash-1.4.3-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-rhash-1.4.4-3-any.pkg.tar.zst";
sha256 = "16ghg6894gb3lcrcpc8g1jd7524djnsjrqcr3krqlzskfv51hgj6"; sha256 = "1ysbxirpfr0yf7pvyps75lnwc897w2a2kcid3nb4j6ilw6n64jmc";
name = "mingw-w64-clang-x86_64-rhash-1.4.3-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-rhash-1.4.4-3-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cmake-3.27.8-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-cmake-3.30.0-1-any.pkg.tar.zst";
sha256 = "1hm5975q0kfd0z502qpxq1imwvby3qb93jzhbvrg2fz9zf5d31rg"; sha256 = "07b7132hwhiqrf0l2lgw3g4zw9i2lln3kqc9kg2qijvkapbkmwqb";
name = "mingw-w64-clang-x86_64-cmake-3.27.8-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-cmake-3.30.0-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-mpdecimal-2.5.1-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-mpdecimal-4.0.0-1-any.pkg.tar.zst";
sha256 = "0zm9pqcgjk9a8ql6hxcxh477wvvyimndcisd6zrr6y8m5vvklsdi"; sha256 = "0hrhbjgi0g3jqpw8himshqw6vazm5sxhsfmyg386nbrxwnfgl1gb";
name = "mingw-w64-clang-x86_64-mpdecimal-2.5.1-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-mpdecimal-4.0.0-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ncurses-6.4.20230708-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-ncurses-6.4.20231217-1-any.pkg.tar.zst";
sha256 = "0l960cf4m558kx6dwahl9z2zdb9vlqd7qgk39kg924mzrfjrsvv5"; sha256 = "00046d52zsr8zjifl7h22jfihhh53h20ipvbqmvf9myssw2fwjza";
name = "mingw-w64-clang-x86_64-ncurses-6.4.20230708-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-ncurses-6.4.20231217-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
@ -283,62 +295,62 @@
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-readline-8.2.001-6-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-readline-8.2.010-1-any.pkg.tar.zst";
sha256 = "1aw1qxjvifkzxnn52l6hba1kcj2dci8llzzj29zbvbq5hgllf6dv"; sha256 = "1s47pd5iz8y3hspsxn4pnp0v3m05ccia40v5nfvx0rmwgvcaz82v";
name = "mingw-w64-clang-x86_64-readline-8.2.001-6-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-readline-8.2.010-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-tcl-8.6.12-2-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-tcl-8.6.13-1-any.pkg.tar.zst";
sha256 = "06r7ni7m1vf83ms8ha3glalc5rfriiffh7v644jmnvzs5g9x0pzb"; sha256 = "0paaqwk0sfy2zxwlxkmxf2bqq46lyg0sx7cqgzknvazwx8xa2z4x";
name = "mingw-w64-clang-x86_64-tcl-8.6.12-2-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-tcl-8.6.13-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-sqlite3-3.44.0-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-sqlite3-3.46.0-1-any.pkg.tar.zst";
sha256 = "022vy54ph6kxmgafbsmgpclvd55ljarvc4k7497ka03g7465zq6j"; sha256 = "0q676i2z5nr4c71jnd4z5qz9xa1xryl0cpi84w74yvd0p4qiz7y2";
name = "mingw-w64-clang-x86_64-sqlite3-3.44.0-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-sqlite3-3.46.0-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-tk-8.6.12-2-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-tk-8.6.13-1-any.pkg.tar.zst";
sha256 = "0pi74q91vl6vw8vvmmwnvrgai3b1aanp0zhca5qsmv8ljh2wdgzx"; sha256 = "12f6lqx1sglczcnz2ns6sxw9cxwm1klxajqzcrbnfwln1nllz2nd";
name = "mingw-w64-clang-x86_64-tk-8.6.12-2-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-tk-8.6.13-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-tzdata-2023c-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-tzdata-2024a-1-any.pkg.tar.zst";
sha256 = "018qw8lk5r7gqvavnl1414gmbd4bx22528a7bjk4injfr3cib9c2"; sha256 = "1lsfn3759cyf56zlmfvgy6ihs4iks6zhlnrbfmnq5wml02k936ji";
name = "mingw-w64-clang-x86_64-tzdata-2023c-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-tzdata-2024a-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-3.11.6-2-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-3.11.9-1-any.pkg.tar.zst";
sha256 = "0j6wkjjay52f4mgvfagi7xql4v74dizhmwdx07n2fixr7j1lg6n7"; sha256 = "0ah1idjqxg7jc07a1gz9z766rjjd0f0c6ri4hpcsimsrbj1zjd3c";
name = "mingw-w64-clang-x86_64-python-3.11.6-2-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-python-3.11.9-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openmp-17.0.4-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-llvm-openmp-18.1.8-1-any.pkg.tar.zst";
sha256 = "1y2cf43fcyb1qv4kn0h0mzljmibaqklcj167krwxq5ymfbipy1yw"; sha256 = "0cy2v0l4af24j34mzj5q5nlzcqhackfajlfj1rpf6mb3rbz23qw9";
name = "mingw-w64-clang-x86_64-openmp-17.0.4-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-llvm-openmp-18.1.8-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openblas-0.3.25-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-openblas-0.3.27-1-any.pkg.tar.zst";
sha256 = "1789hd7sc249vbzppqqg56i080c8kc6bj6qngn9rjvzcdllh0d07"; sha256 = "06ygz1wa488wqvmxbn74b0fyan4wf3lb6kbwfampgikd1gijww2k";
name = "mingw-w64-clang-x86_64-openblas-0.3.25-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-openblas-0.3.27-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-numpy-1.26.2-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-numpy-1.26.4-1-any.pkg.tar.zst";
sha256 = "1njbxlxha05c8fzm7yl3ksjlpanh5jxn1z1s06iv9knx2zmqjk6f"; sha256 = "00h0ap954cjwlsc3p01fjwy7s3nlzs90v0kmnrzxm0rljmvn4jkf";
name = "mingw-w64-clang-x86_64-python-numpy-1.26.2-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-python-numpy-1.26.4-1-any.pkg.tar.zst";
}) })
(pkgs.fetchurl { (pkgs.fetchurl {
url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-setuptools-68.2.2-1-any.pkg.tar.zst"; url = "https://mirror.msys2.org/mingw/clang64/mingw-w64-clang-x86_64-python-setuptools-70.2.0-1-any.pkg.tar.zst";
sha256 = "16m1ph7yq3mqzaaaka5vij27jzaw59lpxwbmpdxkcp8lb6h5xcn9"; sha256 = "1q4r9bg2hn3jmshvq81xm5zvy9wn35yf0z2ayksrkwph1zzdkvkm";
name = "mingw-w64-clang-x86_64-python-setuptools-68.2.2-1-any.pkg.tar.zst"; name = "mingw-w64-clang-x86_64-python-setuptools-70.2.0-1-any.pkg.tar.zst";
}) })
] ]

View File

@ -1,3 +1,13 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(clippy::semicolon_if_nothing_returned, clippy::uninlined_format_args)]
use std::env; use std::env;
static mut NOW: i64 = 0; static mut NOW: i64 = 0;
@ -29,17 +39,17 @@ pub extern "C" fn rtio_get_counter() -> i64 {
#[no_mangle] #[no_mangle]
pub extern "C" fn rtio_output(target: i32, data: i32) { pub extern "C" fn rtio_output(target: i32, data: i32) {
println!("rtio_output @{} target={:04x} data={}", unsafe { NOW }, target, data); println!("rtio_output @{} target={target:04x} data={data}", unsafe { NOW });
} }
#[no_mangle] #[no_mangle]
pub extern "C" fn print_int32(x: i32) { pub extern "C" fn print_int32(x: i32) {
println!("print_int32: {}", x); println!("print_int32: {x}");
} }
#[no_mangle] #[no_mangle]
pub extern "C" fn print_int64(x: i64) { pub extern "C" fn print_int64(x: i64) {
println!("print_int64: {}", x); println!("print_int64: {x}");
} }
#[no_mangle] #[no_mangle]
@ -47,12 +57,11 @@ pub extern "C" fn __nac3_personality(_state: u32, _exception_object: u32, _conte
unimplemented!(); unimplemented!();
} }
fn main() { fn main() {
let filename = env::args().nth(1).unwrap(); let filename = env::args().nth(1).unwrap();
unsafe { unsafe {
let lib = libloading::Library::new(filename).unwrap(); let lib = libloading::Library::new(filename).unwrap();
let func: libloading::Symbol<unsafe extern fn()> = lib.get(b"__modinit__").unwrap(); let func: libloading::Symbol<unsafe extern "C" fn()> = lib.get(b"__modinit__").unwrap();
func() func();
} }
} }