Compare commits

..

650 Commits

Author SHA1 Message Date
5839badadd [standalone] Update globals.py with type-inferred global var 2024-10-07 20:44:08 +08:00
56c845aac4 [standalone] Add support for registering globals without type decl 2024-10-07 20:44:06 +08:00
65a12d9ab3 [core] Refactor registration of top-level variables 2024-10-07 17:05:48 +08:00
9c6685fa8f [core] typecheck/function_check: Fix lookup of defined ids in scope 2024-10-07 16:51:37 +08:00
2bb788e4bb [core] codegen/expr: Materialize implicit globals
Required for when globals are read without the use of a global
declaration.
2024-10-07 13:13:20 +08:00
42a2f243b5 [core] typecheck: Disallow redeclaration of var shadowing global 2024-10-07 13:11:00 +08:00
3ce2eddcdc [core] typecheck/type_inferencer: Infer whether variables are global 2024-10-07 13:10:46 +08:00
51bf126a32 [core] typecheck/type_inferencer: Differentiate global symbols
Required for analyzing use of global symbols before global declaration.
2024-10-07 12:25:00 +08:00
1a197c67f6 [core] toplevel/composer: Reduce lock scope while analyzing function 2024-10-05 15:53:20 +08:00
581b2f7bb2 [standalone] Add demo for global variables 2024-10-04 13:24:30 +08:00
746329ec5d [standalone] Implement symbol resolution for globals 2024-10-04 13:24:30 +08:00
e60e8e837f [core] Add support for global statements 2024-10-04 13:24:27 +08:00
9fdbe9695d [core] Add generator to SymbolResolver::get_symbol_value
Needed in a future commit.
2024-10-04 13:20:29 +08:00
8065e73598 [core] toplevel/composer: Add type analysis for global variables 2024-10-04 13:20:29 +08:00
192290889b [core] Add IdentifierInfo
Keeps track of whether an identifier refers to a global or local
variable.
2024-10-04 13:20:24 +08:00
1407553a2f [core] Implement parsing of global variables
Globals are now parsed into symbol resolver and top level definitions.
2024-10-04 13:18:29 +08:00
c7697606e1 [core] Add TopLevelDef::Variable 2024-10-04 13:09:25 +08:00
88d0ccbf69 [standalone] Explicit panic when encountering a compilation error
Otherwise scripts will continue to execute.
2024-10-04 13:00:16 +08:00
a43b59539c [meta] Move variables declarations closer to where they are first used 2024-10-04 13:00:16 +08:00
fe06b2806f [meta] Reorganize order of use declarations
Use declarations are now grouped into 4 groups:

- Declarations from the standard library
- Declarations from external crates
- Declarations from other crates in this project
- Declarations from within this module

Furthermore, all use declarations are grouped together to enhance
readability. super::super is also replaced by an equivalent crate::
declaration.
2024-10-04 12:52:01 +08:00
7f6c9a25ac [meta] Update Cargo dependencies 2024-10-04 12:52:01 +08:00
6c8382219f msys2: get python via numpy dependencies 2024-09-30 14:27:30 +08:00
9274a7b96b flake: update nixpkgs 2024-09-30 14:22:40 +08:00
d1c0fe2900 cargo: update dependencies 2024-09-30 14:14:43 +08:00
f2c047ba57 artiq: support async rpcs
Co-authored-by: mwojcik <mw@m-labs.hk>
Co-committed-by: mwojcik <mw@m-labs.hk>
2024-09-13 12:12:13 +08:00
5e2e77a500 [meta] Bump inkwell to v0.5 2024-09-13 11:11:14 +08:00
f3cc4702b9 [meta] Update dependencies 2024-09-13 11:11:14 +08:00
3e92c491f5 [standalone] Add tests creating ndarrays with tuple dims 2024-09-11 15:52:43 +08:00
7f629f1579 core: fix comment in unify_call 2024-09-11 15:46:19 +08:00
5640a793e2 core: allow np_full to take tuple shapes 2024-09-11 15:46:19 +08:00
abbaa506ad [standalone] Remove redundant recreation of TargetMachine 2024-09-09 14:27:10 +08:00
f3dc02d646 [meta] Apply cargo fmt 2024-09-09 14:24:52 +08:00
ea217eaea1 [meta] Update pre-commit config
Directly invoke cargo using nix develop to avoid using the system cargo.
2024-09-09 14:24:38 +08:00
5a34551905 allow the use of the LLVM shared library
Which in turns allows working around the incompatibility of the LLVM static library
with Rust link-args=-rdynamic, which produces binaries that either fail to link (OpenBSD)
or segfault on startup (Linux).

The year is 2024 and compiler toolchains are still a trash fire like this.
2024-09-09 11:17:31 +08:00
6098b1b853 fix previous commit 2024-09-06 11:32:08 +08:00
668ccb1c95 nac3core: expose inkwell and nac3parser 2024-09-06 11:06:26 +08:00
a3c624d69d update all dependencies 2024-09-06 10:21:58 +08:00
bd06155f34 irrt: compatibility with pre-C23 compilers 2024-09-05 18:54:55 +08:00
9c33c4209c [core] Fix type of ndarray.element_type
Should be the element type of the NDArray itself, not the pointer to its
type.
2024-08-30 22:47:38 +08:00
122983f11c flake: update dependencies 2024-08-30 14:45:38 +08:00
71c3a65a31 [core] codegen/stmt: Fix obtaining return type of sret functions 2024-08-29 19:15:30 +08:00
8c540d1033 [core] codegen/stmt: Add more casts for boolean types 2024-08-29 16:36:32 +08:00
0cc60a3d33 [core] codegen/expr: Fix missing cast to i1 2024-08-29 16:36:32 +08:00
a59c26aa99 [artiq] Fix RPC of ndarrays from host 2024-08-29 16:08:45 +08:00
02d93b11d1 [meta] Update dependencies 2024-08-29 14:32:21 +08:00
59cad5bfe1
standalone: clang-format demo.c 2024-08-29 10:37:24 +08:00
4318f8de84
standalone: improve src/assignment.py 2024-08-29 10:33:58 +08:00
15ac00708a [core] Use quoted include paths instead of angled brackets
This is preferred for user-defined headers.
2024-08-28 16:37:03 +08:00
c8dfdcfdea
standalone & artiq: remove class_names from resolver 2024-08-27 23:43:40 +08:00
600a5c8679 Revert "standalone: reformat demo.c"
This reverts commit 308edb8237.
2024-08-27 23:06:49 +08:00
22c4d25802 core/typecheck: add missing typecheck in matmul 2024-08-27 22:59:39 +08:00
308edb8237 standalone: reformat demo.c 2024-08-27 22:55:22 +08:00
9848795dcc core/irrt: add exceptions and debug utils 2024-08-27 22:55:22 +08:00
58222feed4 core/irrt: split into headers 2024-08-27 22:55:22 +08:00
518f21d174 core/irrt: build.rs capture IR defined constants 2024-08-27 22:55:22 +08:00
e8e49684bf core/irrt: build.rs capture IR defined types 2024-08-27 22:55:22 +08:00
b2900b4883 core/irrt: use +std=c++20 to compile
To explicitly set the C++ variant and avoid inconsistencies.
2024-08-27 22:55:22 +08:00
c6dade1394 core/irrt: reformat 2024-08-27 22:55:22 +08:00
7e3fcc0845 add .clang-format 2024-08-27 22:55:22 +08:00
d3b4c60d7f core/irrt: comment build.rs & move irrt to nac3core/irrt 2024-08-27 22:55:22 +08:00
5b2b6db7ed core: improve error messages 2024-08-26 18:37:55 +08:00
15e62f467e standalone: add tests for polymorphism 2024-08-26 18:37:55 +08:00
2c88924ff7 core: add support for simple polymorphism 2024-08-26 18:37:55 +08:00
a744b139ba core: allow Call and AnnAssign in init block 2024-08-26 18:37:55 +08:00
2b2b2dbf8f [core] Fix resolution of exception names in raise short form
Previous implementation fails as `resolver.get_identifier_def` in ARTIQ
would return the exception __init__ function rather than the class.

We fix this by limiting the exception class resolution to only include
raise statements, and to force the exception name to always be treated
as a class.

Fixes #501.
2024-08-26 18:35:02 +08:00
d9f96dab33 [core] Add codegen_unreachable 2024-08-23 13:10:55 +08:00
c5ae0e7c36 [standalone] Add tests for tuple equality 2024-08-21 16:25:32 +08:00
b8dab6cf7c [standalone] Add tests for string equality 2024-08-21 16:25:32 +08:00
4d80ba38b7 [core] codegen/expr: Implement comparison of tuples 2024-08-21 16:25:32 +08:00
33929bda24 [core] typecheck/typedef: Add support for tuple methods 2024-08-21 16:25:32 +08:00
a8e92212c0 [core] codegen/expr: Implement string equality 2024-08-21 16:25:32 +08:00
908271014a [core] typecheck/magic_methods: Add equality methods to string 2024-08-21 16:25:32 +08:00
c407622f5c [core] codegen/expr: Add compilation error for unsupported cmpop 2024-08-21 15:46:13 +08:00
d7952d0629 [core] codegen/expr: Fix assertions not generated for -O0 2024-08-21 15:36:54 +08:00
ca1395aed6 [core] codegen: Remove redundant return 2024-08-21 15:36:54 +08:00
7799aa4987 [meta] Do not specify rev in dependency version 2024-08-21 15:36:54 +08:00
76016a26ad [meta] Apply clippy suggestions 2024-08-21 13:07:57 +08:00
8532bf5206
standalone: add missing test_ndarray_ceil() run 2024-08-21 11:39:00 +08:00
2cf64d8608
apply clippy comment changes 2024-08-21 11:21:10 +08:00
706759adb2
artiq: apply cargo fmt 2024-08-21 11:21:10 +08:00
b90cf2300b
core/fix: add missing lifetime in gen_for* 2024-08-21 11:05:30 +08:00
0fc26df29e flake: update nixpkgs 2024-08-19 23:53:15 +08:00
0b074c2cf2 [artiq] symbol_resolver: Set private linkage for constants 2024-08-19 14:41:43 +08:00
a0f6961e0e cargo: update dependencies 2024-08-19 13:15:03 +08:00
b1c5c2e1d4 [artiq] Fix RPC of ndarrays to host 2024-08-15 15:41:24 +08:00
69320a6cf1 [artiq] Fix LLVM representation of strings
Should be `%str` rather than `[N x i8]`.
2024-08-14 09:30:08 +08:00
9e0601837a core: Add compile-time feature to disable escape analysis 2024-08-14 09:29:48 +08:00
432c81a500
core: update insta after #489 2024-08-13 15:30:34 +08:00
6beff7a268 [artiq] Implement core_log and rtio_log in terms of polymorphic_print
Implementation mostly references the original implementation in Python.
2024-08-13 15:19:03 +08:00
6ca7aecd4a [artiq] Add core_log and rtio_log function declarations 2024-08-13 15:19:03 +08:00
8fd7216243 [core] toplevel/composer: Add lateinit_builtins
This is required for the new core_log and rtio_log functions, which take
a generic type as its parameter. However, in ARTIQ builtins are
initialized using one unifier and then actually used by another unifier.

lateinit_builtins workaround this issue by deferring the initialization
of functions requiring type variables until the actual unifier is ready.
2024-08-13 15:19:03 +08:00
4f5e417012 [core] codegen: Add function to get format constants for integers 2024-08-13 15:19:03 +08:00
a0614bad83 [core] codegen/expr: Make gen_string return StructValue
So that it is clear that the value itself is returned rather than a
pointer to the struct or its data.
2024-08-13 15:19:03 +08:00
5539d144ed [core] Add CodeGenContext::build_in_bounds_gep_and_load
For safer accesses to `gep`-able values and faster fails.
2024-08-13 15:19:03 +08:00
b3891b9a0d standalone: Fix several issues post script refactoring
- Add helptext for check_demos.sh
- Add back support for using debug NAC3 for running tests
- Output error message when argument is not recognized
- Fixed last non-demo script argument being ignored
- Add back SSE2 requirement to NAC3 (required for mandelbrot)
2024-08-13 15:19:03 +08:00
6fb8939179 [meta] Update dependencies 2024-08-13 15:19:03 +08:00
973dc5041a core/typecheck: Support tuple arg type in len() 2024-08-13 15:02:59 +08:00
d0da688aa7 standalone: Add tuple len test 2024-08-13 15:02:59 +08:00
12c4e1cf48 core/toplevel/builtins: Add support for len() on tuples 2024-08-13 15:02:59 +08:00
9b988647ed core/toplevel/builtins: Extract len() into builtin function 2024-08-13 15:02:59 +08:00
35a7cecc12
core/typecheck: fix np_array ndmin bug 2024-08-13 12:50:04 +08:00
7e3d87f841 core/codegen: fix bug in call_ceil function 2024-08-07 16:40:55 +08:00
ac0d83ef98 standalone: Add vararg.py 2024-08-06 11:48:42 +08:00
3ff6db1a29 core/codegen: Add va_start and va_end intrinsics 2024-08-06 11:48:42 +08:00
d7b806afb4 core/codegen: Implement support for va_info on supported architectures 2024-08-06 11:48:40 +08:00
fac60c3974 core/codegen: Handle vararg in function generation 2024-08-06 11:46:00 +08:00
f5fb504a15 core/codegen/expr: Implement vararg handling in gen_call 2024-08-06 11:46:00 +08:00
faa3bb97ad core/typecheck/typedef: Add vararg to Unifier::stringify 2024-08-06 11:46:00 +08:00
6a64c9d1de core/typecheck/typedef: Add is_vararg_ctx to TTuple 2024-08-06 11:45:54 +08:00
3dc8498202 core/typecheck/typedef: Handle vararg parameters in unify_call 2024-08-06 11:43:13 +08:00
cbf79c5e9c core/typecheck/typedef: Add is_vararg to FuncArg, ConcreteFuncArg 2024-08-06 11:43:13 +08:00
b8aa17bf8c core/toplevel/composer: Add parsing for vararg parameter 2024-08-06 10:52:24 +08:00
f5b998cd9c core/codegen: Remove unnecessary mut from get_llvm*_type 2024-08-06 10:52:24 +08:00
c36f85ecb9 meta: Update dependencies 2024-08-06 10:52:24 +08:00
3a8c385e01 core/typecheck: fix missing ExprKind::Asterisk in fix_assignment_target_context 2024-08-05 19:30:48 +08:00
221de4d06a core/codegen: add missing comment 2024-08-05 19:30:48 +08:00
fb9fe8edf2 core: reimplement assignment type inference and codegen
- distinguish between setitem and getitem
- allow starred assignment targets, but the assigned value would be a tuple
- allow both [...] and (...) to be target lists
2024-08-05 19:30:48 +08:00
894083c6a3 core/codegen: refactor gen_{for,comprehension} to match on iter type 2024-08-05 19:30:48 +08:00
669c6aca6b clean up and fix 32-bit demos 2024-08-05 19:04:25 +08:00
63d2b49b09 core: remove np_linalg_matmul 2024-08-05 11:44:55 +08:00
bf709889c4 standalone/demo: separate linalg functions from main workspace 2024-08-05 11:44:54 +08:00
1c72698d02 core: add np_linalg_det and np_linalg_matrix_power functions 2024-07-31 18:02:54 +08:00
54f883f0a5 core: implement np_dot using LLVM_IR 2024-07-31 15:53:51 +08:00
4a6845dac6 standalone: add np.transpose and np.reshape functions 2024-07-31 13:23:07 +08:00
00236f48bc core: add np.transpose and np.reshape functions 2024-07-31 13:23:07 +08:00
a3e6bb2292 core/helper: add linalg section 2024-07-31 13:23:07 +08:00
17171065b1 standalone: link linalg at runtime 2024-07-31 13:23:07 +08:00
540b35ec84 standalone: move linalg functions to demo 2024-07-31 13:23:05 +08:00
4bb00c52e3 core/builtin_fns: improve error reporting 2024-07-31 13:21:31 +08:00
faf07527cb standalone: add runtime implementation for linalg functions 2024-07-31 13:21:28 +08:00
d6a4d0a634 standalone: add linalg methods and tests 2024-07-29 16:48:06 +08:00
2242c5af43 core: add linalg methods 2024-07-29 16:48:06 +08:00
318a675ea6 standalone: Rename -m32 to -i386 2024-07-29 14:58:58 +08:00
32e52ce198 standalone: Revert using uint32_t as slice length
Turns out list and str have always been size_t.
2024-07-29 14:58:29 +08:00
665ca8e32d cargo: update dependencies 2024-07-27 22:24:56 +08:00
12c12b1d80 flake: update nixpkgs 2024-07-27 22:22:20 +08:00
72972fa909 core/toplevel: add more numpy categories 2024-07-27 21:57:47 +08:00
142cd48594 core/toplevel: reorder PrimDef::details 2024-07-27 21:57:47 +08:00
8adfe781c5 core/toplevel: fix PrimDef method names 2024-07-27 21:57:47 +08:00
339b74161b core/toplevel: reorganize PrimDef 2024-07-27 21:57:47 +08:00
8c5ba37d09 standalone: Add 32-bit execution tests to check_demo.sh 2024-07-26 13:35:40 +08:00
05a8948ff2 core: Minor cleanup to use ListValue APIs 2024-07-26 13:35:40 +08:00
6d171ec284 core: Add label name and hooks to gen_for functions 2024-07-26 13:35:40 +08:00
0ba68f6657 core: Set target triple and datalayout for each module
Fixes an issue with inconsistent pointer sizes causing crashes.
2024-07-26 13:35:40 +08:00
693b2a8863 core: Add support for 32-bit size_t on 64-bit targets 2024-07-26 13:35:40 +08:00
5faeede0e5 Determine size_t using LLVM target machine 2024-07-26 13:35:38 +08:00
266707df9d standalone: Add support for running 32-bit binaries 2024-07-26 13:32:38 +08:00
3d3c258756 standalone: Remove support for --lli 2024-07-26 13:32:38 +08:00
ed1182cb24 standalone: Update format specifiers for exceptions
Use platform-agnostic identifiers instead.
2024-07-26 13:32:37 +08:00
fd025c1137 standalone: Use uint32_t for cslice length
Matching the expected type of string and list slices.
2024-07-26 13:32:21 +08:00
f139db9af9 meta: Update dependencies 2024-07-26 10:33:02 +08:00
44487b76ae standalone: interpret_demo.py remove duplicated section 2024-07-22 17:23:35 +08:00
1332f113e8 standalone: fix interpret_demo.py comments 2024-07-22 17:06:14 +08:00
7632d6f72a cargo: update dependencies 2024-07-21 11:00:25 +08:00
4948395ca2 core/toplevel/type_annotation: Add handling for mismatching class def
Primitive types only contain fields in its Type and not its TopLevelDef.
This causes primitive object types to lack some fields.
2024-07-19 14:42:14 +08:00
3db3061d99 artiq/symbol_resolver: Handle type of zero-length lists 2024-07-19 14:42:14 +08:00
51c2175c80 core/codegen/stmt: Convert assertion values to i1 2024-07-19 14:42:14 +08:00
1a31a50b8a
standalone: fix __nac3_raise def in demo.c 2024-07-17 21:22:08 +08:00
6c10e3d056 core: cargo clippy 2024-07-12 21:18:53 +08:00
2dbc1ec659 cargo fmt 2024-07-12 21:16:38 +08:00
c80378063a add np_argmin/argmax to interpret_demo environment 2024-07-12 13:27:52 +02:00
513d30152b core: support raise exception short form 2024-07-12 18:58:34 +08:00
45e9360c4d standalone: Add np_argmax and np_argmin tests 2024-07-12 18:19:56 +08:00
2e01b77fc8 core: refactor np_max/np_min functions 2024-07-12 18:18:54 +08:00
cea7cade51 core: add np_argmax/np_argmin functions 2024-07-12 18:18:28 +08:00
d658d9b00e update dependencies, Python 3.12 on Linux 2024-07-09 23:56:12 +08:00
eeb474f9e6 core: reduce code duplication in codegen/extern_fns (#453)
Used macros to reduce code duplication in `codegen/extern_fns`

Reviewed-on: M-Labs/nac3#453
Co-authored-by: abdul124 <ar@m-labs.hk>
Co-committed-by: abdul124 <ar@m-labs.hk>
2024-07-09 16:31:08 +08:00
88b72af2d1 core/llvm_intrinsic: improve macro name and comments 2024-07-09 16:30:32 +08:00
b73f6c4d68 core: reduce code duplication in codegen/llvm_intrinsic 2024-07-09 16:30:32 +08:00
f47cdec650 standalone: Fix output format of output_range 2024-07-09 13:55:48 +08:00
d656880e44 standalone: Fix missing implementation for output_range 2024-07-09 13:53:50 +08:00
a91602915a core: Fix missing fields in range type 2024-07-09 13:53:50 +08:00
1c56005a01 core: Reformat and modernize irrt.cpp
- Use anon namespace instead of static
- Use using declaration instead of typedef
- Align pointers to the type instead of the identifier
2024-07-09 13:53:50 +08:00
bc40a32524 core: Add report_type_error to enable more code reuse 2024-07-09 13:44:47 +08:00
c820daf5f8 core: Apply cargo format 2024-07-09 13:32:10 +08:00
25d2de67f7 standalone: Add output_range and tests 2024-07-09 04:44:40 +08:00
2cfb7a7e10 core: Refactor range function into constructor 2024-07-09 04:44:40 +08:00
9238a5e86e standalone: Rename output_str to output_strln and add output_str
output_str is for outputting strings without newline, and the newly
introduced output_strln now has the old behavior of ending with a
newline.
2024-07-09 04:44:40 +08:00
76defac462 meta: use clang -x c++ instead of clang++ 2024-07-07 20:03:34 +08:00
650f354b74 core: use C++ for irrt source 2024-07-07 14:36:10 +08:00
f062ef5f59 core/llvm_intrinsic: replace roundeven with rint 2024-07-07 14:24:18 +08:00
f52086b706 core: improve binop and cmpop error messages 2024-07-05 16:27:24 +08:00
0a732691c9 core: refactor typecheck/magic_methods.rs operators & add op symbol name 2024-07-05 16:27:20 +08:00
cbff356d50 core: workaround inkwell on llvm.stackrestore 2024-07-05 13:56:12 +08:00
24ac3820b2 core: check int32 obj_id directly in fold_numpy_function_call_shape_argument 2024-07-05 10:36:47 +08:00
ba32fab374 standalone: Add demos for list arithmetic operators 2024-07-04 16:01:15 +08:00
c4052b6342 core: Implement multi-operand __eq__ and __ne__ for lists 2024-07-04 16:01:15 +08:00
66c205275f core: Implement list::__add__ 2024-07-04 16:01:11 +08:00
c85e412206 core: Implement list::__mul__ 2024-07-04 15:53:50 +08:00
075536d7bd core: Add BreakContinueHooks for gen_for_callback 2024-07-04 15:32:18 +08:00
13beeaa2bf core: Implement handling for zero-length lists 2024-07-04 15:32:18 +08:00
2194dbddd5 core/type_annotation: Refactor List type to TObj
In preparation for operators on lists.
2024-07-04 15:32:18 +08:00
94a1d547d6 meta: Update dependencies 2024-07-04 15:32:18 +08:00
d6565feed3 core: ndarray_from_ndlist_impl cast size_of to usize 2024-07-04 12:24:52 +08:00
83154ef8e1 core/llvm_intrinsics: remove llvm.roundeven call from call_float_roundeven 2024-07-03 14:17:47 +08:00
0744b938b8 core: fix __nac3_ndarray_calc_size crash due to incorrect typing 2024-07-03 13:03:14 +08:00
56fa2b6803 core: fix crash on iterating over non-iterables
a
2024-06-28 15:45:53 +08:00
d06c13f936 core: fix crash on invalid subscripting 2024-06-27 16:58:48 +08:00
9808923258 core: improve comments in type_inferencer/mod.rs 2024-06-27 14:46:48 +08:00
5b11a1dbdd core: support tuple and int32 input for np_empty, np_ones, and more 2024-06-27 14:30:17 +08:00
b21df53e0d core: fix comment typo in unify_call() 2024-06-27 14:06:39 +08:00
0ec967a468 core: improve function call errors 2024-06-27 14:06:39 +08:00
ca8459dc7b standalone: prettify TopLevelComposer error reporting 2024-06-27 10:15:14 +08:00
b0b804051a nac3artiq: allow class attribute access without init function 2024-06-25 16:06:33 +08:00
134af79fd6 core: add support for class attributes 2024-06-25 16:06:33 +08:00
7fe2c3496c core: add attribute field to class definition 2024-06-25 16:06:33 +08:00
144a3fc426 core: more derive Debug in typedef 2024-06-25 15:02:50 +08:00
74096eb9f6 core: name codegen worker threads 2024-06-25 12:36:37 +08:00
06e9d90d57 apply clippy changes 2024-06-21 14:14:01 +08:00
d89146aa02 core: use no_run on builtin_fns docs 2024-06-20 13:53:25 +08:00
5bade81ddb standalone: Add test for multidim array index with one index 2024-06-20 12:50:30 +08:00
0452e6de78 core: Fix codegen for tuple-index into ndarray 2024-06-20 12:50:30 +08:00
635c944c90 core: Fix type inference for tuple-index into ndarray
Fixes #420.
2024-06-20 12:50:30 +08:00
e36af3b0a3 core: reduce code duplication in codegen/builtin_fns (#422)
Used macros to generate some unary math functions.

Reviewed-on: M-Labs/nac3#422
Reviewed-by: David Mak <chmakac@connect.ust.hk>
Co-authored-by: lyken <lyken@m-labs.hk>
Co-committed-by: lyken <lyken@m-labs.hk>
2024-06-20 12:48:44 +08:00
5b1aa812ed update dependencies 2024-06-20 10:43:55 +08:00
d3cd2a8d99 artiq: Add support for generating RPC tag for ndarray 2024-06-19 18:56:16 +08:00
202a63274d artiq: Implement pyty-to-ty conversion 2024-06-19 18:56:15 +08:00
76dd5191f5 artiq: Implement Python-to-LLVM conversion of ndarray 2024-06-19 18:56:15 +08:00
8d9df0a615 artiq: Fix ndarray class ID
We want the class ID of the ndarray class, not its corresponding typing
class.
2024-06-19 18:56:15 +08:00
07adfb2a18 standalone: Add *.ll to Gitignore list 2024-06-19 18:56:15 +08:00
f00e458f60 add test for class without __init__ 2024-06-19 18:16:54 +08:00
1bc95a7ba6 Add handling for np.bool_ and np.str_ 2024-06-19 15:10:47 +08:00
e85f4f9bd2 core: refactor top_level::builtins::get_builtins() 2024-06-18 11:06:25 +08:00
ce3e9bf4fe nac3artiq: add support string attributes in classes 2024-06-17 16:53:51 +08:00
82091b1be8 meta: Apply clippy changes 2024-06-17 14:10:31 +08:00
32919949e2 Run clippy --tests on pre-commit hook 2024-06-17 12:51:25 +08:00
2abe75d1f4 core: remove code dup with make_exception_fields 2024-06-17 12:01:48 +08:00
676412fe6d apply cargo fmt 2024-06-14 09:46:42 +08:00
8b9df7252f core: cleanup with Unifier::generate_var_id 2024-06-14 09:42:04 +08:00
6979843431 core: fix typo in into_var_map 2024-06-13 16:59:10 +08:00
fed1361c6a core: rename to_var_map to into_var_map 2024-06-13 16:59:10 +08:00
aa94e0c8a4 core: remove pub & add From<TypeVarId> for u32 2024-06-13 16:59:10 +08:00
f523e26227 core: fix typo in fmt::Display of TypeVarId 2024-06-13 16:59:10 +08:00
f026b48e2a core: refactor to use TypeVarId and TypeVar 2024-06-13 16:59:10 +08:00
dc874f2994 core: use PrimDef simple names in make_primitives() 2024-06-13 16:58:32 +08:00
95de0800b4 core/demo: fix typo in .gitignore 2024-06-13 16:05:33 +08:00
3d71c6a850 core/demo: gitignore to ignore *.bc & *.o 2024-06-13 16:00:23 +08:00
be55e2ac80 meta: Update README to include info regarding pre-commit hooks 2024-06-12 16:10:57 +08:00
79c8b759ad meta: Add pre-commit configuration 2024-06-12 16:10:57 +08:00
4798c53a21 flake: Add pre-commit to dev environment 2024-06-12 16:10:57 +08:00
23974feae7 meta: Restrict number of allowed lints 2024-06-12 16:10:57 +08:00
40a3bded36 meta: Set clippy lints in {main,lib}.rs
So that this does not have to be manually passed to the `cargo clippy`
command-line every single time. Also allows incrementally addressing
these lints by removing and fixing them one-by-one.
2024-06-12 16:10:57 +08:00
c4420e6ab9 core: refactor get_builtins() 2024-06-12 15:09:20 +08:00
fd36f78005 core: refactor PrimitiveDefinitionId into enum PrimDef 2024-06-12 15:01:01 +08:00
8168692cc3 apply cargo fmt 2024-06-12 14:45:03 +08:00
53d44b9595 standalone: Add np_array tests 2024-06-11 16:44:36 +08:00
6153f94b05 core/numpy: Implement codegen for np_array 2024-06-11 16:42:11 +08:00
4730b595f3 core/builtins: Add np_array function 2024-06-11 16:42:08 +08:00
c2fdb12397 core/type_inferencer: Add special rule for np_array 2024-06-11 16:40:35 +08:00
82bf14785b core: Add multidimensional array helpers 2024-06-11 15:30:06 +08:00
2d4329e23c core/stmt: Use BB of last statement in if-else in phi 2024-06-11 15:30:06 +08:00
679656f9e1 core/classes: Fix incorrect field locations for lists 2024-06-11 15:30:06 +08:00
210d9e2334 core: Add more creator functions for ProxyType 2024-06-11 15:26:37 +08:00
181ac3ec1a core/classes: Fix incorrect pointers of range.{stop,step} 2024-06-11 15:13:31 +08:00
3acdfb304d meta: Apply clippy suggestions 2024-06-11 14:58:32 +08:00
6e24da9cc5 meta: Update dependencies 2024-06-11 14:58:32 +08:00
f0ab1b858a core: Refactor class abstractions
- Introduce new Type abstractions
- Rearrange some functions
2024-06-06 13:45:51 +08:00
08129cc635 nac3core: add TopLevelComposer::new builtin check's assertion msg 2024-06-05 15:30:02 +08:00
ad4832dcf4 core: Refactor to get LLVM intrinsics via Intrinsics::find 2024-06-05 15:29:40 +08:00
520bbb246b flake: add llvmPackages_14.llvm to devShells linux default (#405)
Co-authored-by: lyken <lyken@m-labs.hk>
Co-committed-by: lyken <lyken@m-labs.hk>
2024-06-05 11:11:56 +08:00
b857f1e403 nac3core: fix typo in gen_for's comment 2024-06-04 17:15:41 +08:00
fa8af37e84 flake: update nixpkgs 2024-06-03 22:22:04 +08:00
23b2fee4e7 standalone: Add test case for ndarray slicing 2024-06-03 16:40:05 +08:00
ed79d5bb9e core/expr: Add support for multi-dim slicing of NDArrays 2024-06-03 16:40:05 +08:00
c35ad06949 core/expr: Add support for 1D slicing of NDArrays 2024-06-03 16:40:05 +08:00
135ef557f9 core/numpy: Implement ndarray_sliced_{copy,copyto_impl}
Performing copying with optional support for slicing. Also made
copy_impl delegate to sliced_copy, as sliced_copy now performs a
superset of operations that copy_impl can already do.
2024-06-03 16:40:05 +08:00
a176c3eb70 core/irrt: Change handle_slice_indices to instead take length of object
So that all other array-like datatypes (e.g. ndarray) can also take
advantage of it.
2024-06-03 16:40:05 +08:00
2cf79510c2 core/numpy: Add more helper functions 2024-06-03 16:40:05 +08:00
b6ff75dcaf core/irrt: Add support for calculating partial size of NDArray 2024-06-03 16:40:05 +08:00
588c15f80d core/stmt: Add gen_for_range_callback
For generating for loops over range objects or array slices.
2024-06-03 16:40:05 +08:00
82cc693b11 meta: Update dependencies 2024-06-03 16:40:02 +08:00
520e1adc56 core/builtins: Add np_minimum/np_maximum 2024-05-09 15:01:20 +08:00
73e81259f3 core/builtins: Add np_min/np_max 2024-05-09 15:01:20 +08:00
7627acea41 core/type_inferencer: Fix error message 2024-05-09 15:01:20 +08:00
a777099ea8 core/type_inferencer: Fix missing lowering for some builtin TVars 2024-05-09 15:01:20 +08:00
876e6ea7b8 meta: Update dependencies 2024-05-08 17:27:38 +08:00
30c6cffbad core/builtins: Refactored numpy builtins to accept scalar and ndarrays 2024-05-06 15:38:29 +08:00
51671800b6 core/builtins: Extract codegen portion into functions
We will need to reuse them when implementing elementwise function
application for ndarrays.
2024-05-06 13:21:54 +08:00
7195476edb core/builtins: Add llvm_intrinsics prefix 2024-05-06 13:21:54 +08:00
eecba0b71d core: Add GenCall::create_dummy
A simple abstraction for GenCalls that are already handled elsewhere.
2024-05-06 13:21:54 +08:00
7b4253ccd8 core/numpy: Add missing lifetime parameters 2024-05-06 13:21:54 +08:00
f58c3a11f8 core/builtins: Rework handling of PrimitiveStore-Unifier tuples 2024-05-06 13:21:54 +08:00
d0766a116f core: Remove Box from GenCallCallback type alias
So that references to the function type can be taken.
2024-05-06 13:21:54 +08:00
64a3751fc2 core: Remove custom function type definitions for ndarray operators 2024-05-06 13:21:54 +08:00
9566047241 standalone: Fix cbrt never tested 2024-05-06 13:21:54 +08:00
062e318dd5 core/magic_methods: Fix clippy warnings 2024-05-06 13:21:54 +08:00
c4dc36ae99 standalone: Add explicit -- for delimiting run args vs NAC3 args 2024-05-06 13:21:54 +08:00
baac348ee6 meta: Update dependencies 2024-05-06 13:21:37 +08:00
847615fc2f core: Implement numpy.matmul for 2D-2D ndarrays 2024-04-23 10:27:37 +08:00
5dfcc63978 core/classes: Take reference of indexes 2024-04-16 17:20:24 +08:00
025b3cd02f core/stmt: Remove gen_if_chained*
Turns out it is really difficult to get lifetimes and closures right, so
let's just provide the most rudimentary if-else codegen and we can nest
them if necessary.
2024-04-16 17:16:50 +08:00
e0f440040c core/expr: Implement negative indices for ndarray 2024-04-15 12:49:42 +08:00
f0715e2b6d core/stmt: Add gen_if* functions
For generating if-constructs in IR.
2024-04-15 12:20:34 +08:00
e7fca67786 core/stmt: Do not generate jumps if bb is already terminated
Future-proofs gen_*_callback functions in case other codegen functions
will delegate to it in the future.
2024-04-15 12:20:34 +08:00
52c731c312 core: Implement Not/UAdd/USub for booleans
Not sure if this is deliberate or an oversight, but we implement it
anyway for consistency with other Python implementations.
2024-04-12 18:29:58 +08:00
00d1b9be9b core: Fix __inv__ for i8-based boolean operands 2024-04-12 15:35:54 +08:00
8404d4c4dc meta: Update dependencies 2024-04-12 15:29:09 +08:00
e614dd4257 core/type_inferencer: Fix location of unary/compare expressions
Codegen uses this location information to determine the CallId, and if
a function call is the operand of a unary expression or left-hand
operand of a compare expression, codegen will use the type of the
operator expression rather than the actual operand type.
2024-04-05 15:42:10 +08:00
937a8b9698 core/magic_methods: Fix type of unary ops with primitive types 2024-04-05 13:23:08 +08:00
876ad6c59c core/type_inferencer: Include location info if inferencer fails 2024-04-05 13:22:35 +08:00
a920fe0501 core: Implement elementwise comparison operators 2024-04-03 00:07:33 +08:00
727a1886b3 core: Implement elementwise unary operators 2024-04-03 00:07:33 +08:00
6af13a8261 core: Implement elementwise binary operators
Including immediate variants of these operators.
2024-04-03 00:07:33 +08:00
3540d0ab29 core/magic_methods: Add typeof_*op
Used to determine the expected type of the binary operator with
primitive operands.
2024-04-03 00:07:33 +08:00
3a6c53d760 core/toplevel/numpy: Split ndarray type var utilities 2024-04-03 00:07:33 +08:00
87bc34f7ec core: Implement calculations for broadcasting ndarrays 2024-04-03 00:07:31 +08:00
f50a5f0345 core/type_inferencer: Allow both int32 and isize when indexing ndarray 2024-04-02 16:49:12 +08:00
a77fd213e0 core/magic_methods: Allow unknown return types
These types can be later inferred by the type inferencer.
2024-04-02 16:49:12 +08:00
8f1497df83 core/helper: Add PrimitiveDefinitionIds::iter 2024-04-02 16:49:12 +08:00
5ca2dbeec8 core/typedef: Add Type::obj_id to replace get_obj_id 2024-04-02 16:49:10 +08:00
9a98cde595 core: Extract codegen portion of gen_*op_expr
This allows *ops to be generated internally using LLVM values as
input. Required in a future change.
2024-04-01 16:48:25 +08:00
5ba8601b39 core: Remove ArrayValue variants of functions
These will be lowered and optimized away later anyways, and we have
ArrayLikeAccessor now.
2024-04-01 16:48:25 +08:00
26a01b14d5 core: Use more typed slices in APIs 2024-04-01 16:48:25 +08:00
d5f4817134 core/builtins: Fix len() on ndarrays 2024-04-01 16:48:24 +08:00
789bfb5a26 core: Fix index-based operations not returning i32 2024-04-01 16:46:45 +08:00
4bb0e60981 core: Apply clippy suggestions 2024-04-01 16:46:41 +08:00
623fcf85af msys2: update 2024-03-25 14:45:36 +08:00
13f06f3e29 core: Refactor VarMap to IndexMap
This is the only Map I can find that preserves insertion order while
also deduplicating elements by key.
2024-03-22 15:51:23 +08:00
f0da9c0283 core: Add ArrayLikeValue
For exposing LLVM values that can be accessed like an array.
2024-03-22 15:51:06 +08:00
2c4bf3ce59 core: Allow unsized CodeGenerator to be passed to some codegen functions
Enables codegen_callback to call these codegen functions as well.
2024-03-22 15:07:28 +08:00
e980f19c93 core: Simplify typed value assertions 2024-03-22 15:07:28 +08:00
cfbc37c1ed core: Add gen_for_callback_incrementing
Simplifies generation of monotonically increasing for loops.
2024-03-22 15:07:28 +08:00
50264e8750 core: Add missing unchecked accessors for NDArrayDimsProxy 2024-03-22 15:07:28 +08:00
1b77e62901 core: Split numpy into codegen and toplevel 2024-03-22 15:07:28 +08:00
fd44ee6887 core: Apply clippy suggestions 2024-03-22 15:07:23 +08:00
c8866b1534 core/classes: Rename get_* functions to remove prefix
As suggested by Rust API Guidelines.
2024-03-21 15:46:10 +08:00
84a888758a core: Rename unsafe functions to unchecked
This is this intended name of the functions.
2024-03-21 15:46:10 +08:00
9d550725b7 meta: Update cargo dependencies 2024-03-21 15:45:26 +08:00
2edc1de0b6 standalone: Update ndarray.py to output all elements in ndarrays 2024-03-07 14:59:13 +08:00
c3b122acfc core: Implement ndarray.copy 2024-03-07 14:59:13 +08:00
a94927a11d core: Update __builtin_assume expressions
No dimension size should be 0.
2024-03-07 14:59:13 +08:00
ebf86cd134 core: Use size_t for accessing array elements 2024-03-07 14:59:13 +08:00
cccd8f2d00 core: Fix ndarray_eye not preserving signness of offset 2024-03-07 14:59:13 +08:00
3292aed099 core: Fix ndarray subscript operator returning the wrong object
Should be returning the newly created object instead of the original
ndarray...
2024-03-07 14:59:13 +08:00
96b7f29679 core: Implement ndarray.fill 2024-03-07 14:59:13 +08:00
3d2abf73c8 core: Replace ndarray_init_dims IRRT impl with IR impl
Implementation of that function in IR allows for more flexibility in
terms of different integer type widths.
2024-03-07 14:59:13 +08:00
f682e9bf7a core: Match IRRT compile flavor with build profile 2024-03-07 14:59:02 +08:00
b26cb2b360 core: Express member func def IDs as offsets from class def ID 2024-03-06 12:24:39 +08:00
2317516cf6 core: Use tvars from ndarray for class definition 2024-03-04 23:58:02 +08:00
77de24ef74 core: Use BTreeMap for type variable mapping
There have been multiple instances where I had the need to iterate over
type variables, only to discover that the traversal order is arbitrary.

This commit fixes that by adding SortedMapping, which utilizes BTreeMap
internally to guarantee a traversal order. All instances of VarMap are
now refactored to use this to ensure that type variables are iterated in
 the order of its variable ID, which should be monotonically incremented
 by the unifier.
2024-03-04 23:56:04 +08:00
234a6bde2a core: Use TObj for NDArray 2024-03-01 15:41:55 +08:00
c3db6297d9 core: Add primitive definition-id list
So that we have a single ground truth for the definition IDs of
primitive types.
2024-03-01 11:29:10 +08:00
82fdb02d13 core: Extract LLVM intrinsic functions to their functions 2024-02-23 15:41:06 +08:00
4efdd17513 core: Add missing From implementations for LLVM wrapper classes 2024-02-23 15:41:06 +08:00
49de81ef1e core: Apply clippy suggestions 2024-02-23 15:41:06 +08:00
8492503af2 core: Update cargo dependencies 2024-02-23 15:41:04 +08:00
e1dbe2526a flake: switch to nixpkgs unstable for newer rustc 2024-02-20 15:46:51 +08:00
f37de381ce update dependencies 2024-02-20 13:33:20 +08:00
4452c8986a update ARTIQ version used for PGO profiling 2024-02-20 13:29:00 +08:00
22e831cb76 core: Add test for indexing into ndarray 2024-02-19 17:13:10 +08:00
cc538d221a core: Implement codegen for indexing into ndarray 2024-02-19 17:13:09 +08:00
0d5c53e60c core: Implement type inference for indexing into ndarray 2024-02-19 17:13:09 +08:00
976a9512c1 core: Add const variants to NDArray element getters 2024-02-19 16:56:21 +08:00
1eacaf9afa core: Fix IRRT argument order to ndarray_flatten_index 2024-02-19 16:37:13 +08:00
8c7e44098a core: Fix IRRT implementation of ndarray_flatten_index 2024-02-19 16:37:13 +08:00
282a3e1911 core: Fix typo in error message 2024-02-14 16:26:13 +08:00
5cecb2bb74 core: Fix Literal use in variable type annotation 2024-02-06 18:16:14 +08:00
1963c30744 core: Use Display output for locations 2024-02-06 18:11:51 +08:00
27011f385b core: Add location to non-primitive value return error 2024-02-02 12:49:21 +08:00
d6302b6ec8 core: Allow tuple of primitives to be returned 2024-02-02 12:48:52 +08:00
fef4b2a5ce standalone: Disable tests requiring return of non-primitive values 2024-01-29 12:49:50 +08:00
b3736c3e99 core: Disallow returning of non-primitive values
Non-primitive values are represented by an `alloca`-ed value in the
function body, and when the pointer is returned from the function, the
`alloca`-ed object is deallocated on the stack.

Related to #54.
2024-01-29 12:49:24 +08:00
e328e44c9a update MSYS2 2024-01-26 15:55:45 +08:00
9e4e90f8a0 update dependencies 2024-01-26 15:52:48 +08:00
8470915809 core: Add NDArrayValue and helper functions 2024-01-25 15:51:39 +08:00
148900302e core: Add RangeValue and helper functions 2024-01-25 15:51:39 +08:00
5ee08b585f core: Add ListValue and helper functions 2024-01-25 15:51:39 +08:00
f1581299fc core: Minor changes to IRRT
Add missing documentation, remove redundant lifetime variables, and fix
typos.
2024-01-25 15:50:53 +08:00
af95ba5012 standalone: Add debug flag to run_demo.sh
Allows running demos using the debug build instead of the (default)
release build.
2024-01-25 15:50:53 +08:00
9c9756be33 standalone: Use size_t in demo.c 2024-01-25 15:50:53 +08:00
2a922c7480 artiq: Fix source module of NDArray
Should be `numpy.typing` instead of `numpy`.
2024-01-17 10:40:08 +08:00
e3e2c36ef4 core: Mark TNDArray and TLiteral as unimplemented in tests 2024-01-17 09:58:14 +08:00
4f9a0110c4 meta: Update insta snapshots 2024-01-17 09:49:50 +08:00
12c0eed0a3 core: Fix compilation of tests 2024-01-17 09:49:49 +08:00
c679474f5c standalone: Fix redefinition of ndarray consumer functions 2024-01-17 09:38:13 +08:00
ab3fa05996 demo: use portable format strings 2024-01-10 18:35:35 +08:00
140f8f8a08 core: Implement most ndarray-creation functions 2023-12-22 16:29:55 +08:00
27fcf8926e core: Implement ndarray constructor and numpy.empty 2023-12-22 16:29:54 +08:00
afa7d9b100 core: Implement helper for creation of generic ndarray 2023-12-21 15:39:49 +08:00
c395472094 core: Initial infrastructure for ndarray 2023-12-21 15:39:46 +08:00
03870f222d core: Extract special method handling in type inferencer
To prepare for more special handling with methods.
2023-12-21 15:38:26 +08:00
e435b25756 core: Allow implicit promotions of integral literals
It should not matter, since it is the value of the literal that matters
with respect to the const generic variable.
2023-12-21 15:21:08 +08:00
bd792904f9 core: Add size_t to primitive store
Used for ndims in ndarray.
2023-12-21 15:20:31 +08:00
1c3a823670 core: Do not discard value names for IRRT 2023-12-20 15:16:02 +08:00
f01d833d48 standalone: Add missing parenthesis 2023-12-20 15:15:47 +08:00
9d64e606f4 core: Reject multiple literal bounds
This is currently broken due to how we handle function calls in the
unifier.
2023-12-18 10:04:25 +08:00
6dccb343bb Revert "core: Do not keep unification result for function arguments"
This reverts commit f09f3c27a5.
2023-12-18 10:01:23 +08:00
d47534e2ad interpret_demo: add typing.Literal 2023-12-18 08:50:49 +08:00
8886964776 core: Remove redundant argument in type annotation parsing 2023-12-16 18:40:48 +08:00
f09f3c27a5 core: Do not keep unification result for function arguments
For some reason, when unifying a function call parameter with an
argument, subsequent calls to the same function will only accept the
type of the substituted argument.

This affect snippets like:

```
def make1() -> C[Literal[1]]:
    return ...

def make2() -> C[Literal[2]]:
    return ...

def consume(instance: C[Literal[1, 2]]):
    pass

consume(make1())
consume(make2())
```

The last statement will result in a compiler error, as the parameter of
consume is replaced with C[Literal[1]].

We fix this by getting a snapshot before performing unification, and
restoring the snapshot after unification succeeds.
2023-12-16 18:40:48 +08:00
0bbc9ce6f5 core: Deduplicate values in Literal
Matches the behavior with `typing.Literal`.
2023-12-16 18:40:48 +08:00
457d3b6cd7 core: Refactor generic constants to Literal
Better matches the syntax of `typing.Literal`.
2023-12-16 18:40:48 +08:00
5f692debd8 core: Add PrimitiveStore into Unifier
This will be used during unification between a const generic variable
and a `Literal`.
2023-12-16 18:40:48 +08:00
c7735d935b standalone: Output id of undefined identifier 2023-12-16 18:40:48 +08:00
b47ac1b89b core: Minor formatting cleanup 2023-12-15 17:46:44 +08:00
a19f1065e3 meta: Refactor to use more let-else bindings 2023-12-12 16:31:14 +08:00
5bf05c6a69 update ARTIQ version used for PGO profiling 2023-12-12 15:57:48 +08:00
32746c37be core: Refactor to return errors by HashSet 2023-12-12 15:41:59 +08:00
1d6291b9ba ast: Add Ord implementation to Location 2023-12-12 15:41:59 +08:00
16655959f2 meta: Update cargo dependencies 2023-12-12 15:41:59 +08:00
beee3e1f7e artiq: Pass artiq builtins to NAC3 constructor 2023-12-12 11:28:03 +08:00
d4c109b6ef core: Add missing generic constant concrete type 2023-12-12 11:28:01 +08:00
5ffd06dd61 core: Remove debugging statement 2023-12-12 11:23:51 +08:00
95d0c3c93c artiq: Rename const_generic_dummy to const_generic_marker 2023-12-12 11:23:51 +08:00
bd3d67f3d6 artiq: Apply clippy pedantic changes 2023-12-11 15:16:23 +08:00
ddfb532b80 standalone: Apply clippy pedantic changes 2023-12-11 15:16:23 +08:00
02933753ca core: Apply clippy pedantic changes 2023-12-11 15:16:23 +08:00
a1f244834f meta: Bringup some documentation 2023-12-11 15:16:23 +08:00
d304afd333 meta: Apply clippy suggested changes 2023-12-11 15:16:23 +08:00
ef04696b02 meta: Lift return out of conditional statement 2023-12-11 15:16:23 +08:00
4dc5dbb856 meta: Replace equality assertion with assert_eq
Emits a more useful assertion message.
2023-12-11 15:16:23 +08:00
fd9f66b8d9 meta: Remove redundant casts and brackets 2023-12-11 15:16:23 +08:00
5182453bd9 meta: Remove redundant path prefixes 2023-12-11 15:16:23 +08:00
68556da5fd update ARTIQ version used for PGO profiling 2023-12-11 09:37:03 +08:00
983f080ea7 artiq: Implement handling for const generic variables 2023-12-08 18:02:14 +08:00
031e660f18 core: Initial implementation for const generics 2023-12-08 18:02:11 +08:00
b6dfcfcc38 core: Move some SymbolValue functions to symbol_resolver.rs 2023-12-08 18:00:51 +08:00
c93ad152d7 core: Codegen for ellipsis expression as NotImplemented
A lot of refactoring was performed, specifically with relaxing
expression codegen to return Option in case where ellipsis are used
within a subexpression.
2023-12-08 18:00:51 +08:00
68b97347b1 core: Infer builtins name list using builtin declaration list 2023-12-08 17:29:34 +08:00
875d534de4 ast: Use {filename}:{row}:{col} for location output 2023-12-08 15:48:54 +08:00
adadf56e2b nac3standalone: generate PIC 2023-12-04 19:09:50 +08:00
9f610745b7 cargo: update dependencies 2023-12-04 18:51:06 +08:00
98199768e3 demo: fix 64-bit format strings 2023-12-04 18:51:06 +08:00
bfa9ceaae3 switch to new nixpkgs release 2023-12-03 10:31:05 +08:00
120f8da5c7 fix compilation warnings 2023-11-26 09:09:24 +08:00
cee62aa6c5 pin down LLVM used for IRRT 2023-11-25 20:15:29 +08:00
fcda360ad6 flake: update dependencies 2023-11-24 18:11:25 +08:00
87c20ada48 windows: switch to CLANG64 MSYS2
For compatibility with MSVC (Anaconda and others).
2023-11-24 18:10:00 +08:00
38e968cff6 gitignore: fix msys2 path 2023-11-24 17:18:17 +08:00
5c5620692f core: Add np_{round,floor,ceil}
These functions are NumPy variants of round/floor/ceil, which returns
floats instead of ints.
2023-11-23 13:45:07 +08:00
0af1e37e99 core: Prefix all NumPy/SciPy functions with np_/sp_spec 2023-11-23 13:35:23 +08:00
854e33ed48 meta: Update cargo dependencies 2023-11-23 13:31:24 +08:00
f020d61cbb update ARTIQ version used for PGO profiling 2023-11-11 11:10:58 +08:00
10538b5296 core: Update insta snapshots 2023-11-09 13:00:27 +08:00
d322c91697 core: Change bitshift operators to accept int32/uint32 for RHS operand 2023-11-09 12:16:20 +08:00
3231eb0d78 core: Add compile-time error and runtime assertion for negative shifts 2023-11-09 12:16:20 +08:00
1ca4de99b9 update ARTIQ version used for PGO profiling 2023-11-08 17:29:29 +08:00
bf4b1aae47 update dependencies 2023-11-08 17:23:49 +08:00
08a5050f9a core: Implement non-trivial builtin functions using IRRT 2023-11-06 12:57:23 +08:00
c2ab6b58ff artiq: Implement with legacy_parallel block 2023-11-04 13:42:44 +08:00
0a84f7ac31 Add CodeGenerator::gen_block and refactor to use it 2023-11-04 13:42:44 +08:00
fd787ca3f5 core: Remove trunc
The behavior of trunc is already implemented by casts and is therefore
redundant.
2023-11-04 13:35:53 +08:00
4dbe07a0c0 core: Revert breaking changes to round-family functions
These functions should return ints as the math.* functions do instead of
following the convention of numpy.* functions.
2023-11-04 13:35:53 +08:00
2e055e8ab1 core: Replace rint implementation with LLVM intrinsic 2023-11-04 13:35:53 +08:00
9d737743c1 standalone: Add regression test for numeric primitive operations 2023-11-03 16:24:26 +08:00
c6b9aefe00 core: Fix int32-to-uint64 conversion
This conversion should be sign-extended.
2023-11-03 16:24:26 +08:00
8ad09748d0 core: Fix conversion from float to unsigned types
These conversions also need to wraparound.
2023-11-03 16:24:26 +08:00
7a5a2db842 core: Fix handling of float-to-int32 casts
Out-of-bound conversions should be wrapped around.
2023-11-03 16:24:26 +08:00
447eb9c387 standalone: Fix output format string for output_uint* 2023-11-03 16:24:26 +08:00
92d6f0a5d3 core: Implement bitwise not for unsigned ints and fix implementation 2023-11-03 16:24:26 +08:00
7e4dab15ae standalone: Add math tests for non-number arguments 2023-11-01 18:03:29 +08:00
ff1fed112c core: Rework gamma/gammaln to match SciPy behavior
Matches behavior for infinities and NaNs.
2023-11-01 18:03:29 +08:00
36a6a7b8cd core: Replace TopLevelDef comments with documentation 2023-11-01 18:03:29 +08:00
2b635a0b97 core: Implement numpy and scipy functions 2023-11-01 18:03:29 +08:00
60ad100fbb core: Implement and expose {isinf,isnan} 2023-11-01 18:03:29 +08:00
316f0824d8 flake: Add scipy 2023-11-01 18:03:29 +08:00
7cf7634985 core: Add create_fn_by_* functions
Used for abstracting the creation of function from different sources.
2023-11-01 18:03:29 +08:00
068f0d9faf core: Do not cast floor/ceil result to int
NumPy explicitly states that the return type of the floor/ceil is float.
2023-11-01 18:03:29 +08:00
95810d4229 core: Remove {ceil64,floor64,round,round64}
These are not present in NumPy or Artiq.
2023-11-01 18:03:29 +08:00
630897b779 standalone: Do not output sign if float is NaN
Matches behavior in Python.
2023-11-01 18:03:29 +08:00
e546535df0 flake: update nixpkgs 2023-11-01 15:53:47 +08:00
352f70b885 artiq: Update host exception list to match possibly thrown types 2023-11-01 13:28:48 +08:00
e95586f61e core: Fix IR generation of for loop containing break/continue
Fix cases where the body BB would have two terminators because of a
preceding continue/break statement already emitting a terminator.
2023-11-01 13:21:27 +08:00
bb27e3d400 standalone: Fix indentation of demo.c 2023-11-01 13:20:26 +08:00
bb5147521f standalone: Fix indentation of test files 2023-11-01 13:20:26 +08:00
9518d3fe14 artiq: Fix timeline not resetting upon exiting sequential block 2023-10-30 14:04:53 +08:00
cbd333ab10 artiq: Extract parallel block timeline utilities 2023-10-30 14:04:53 +08:00
65d6104d00 artiq: Improve IR value naming and add documentation 2023-10-30 14:04:53 +08:00
8373a6cb0f artiq: Use gen_block when generating "with sequential" 2023-10-30 14:04:53 +08:00
f75ae78677 cargo: Update dependencies 2023-10-30 14:04:53 +08:00
ea2ab0ef7c update nixpkgs, python 3.11 2023-10-25 21:09:22 +08:00
e49b760e34 ld: Support multiple CFIs with different encoding in .eh_frame
We now parse each CFI to read its encoding as opposed to assuming that
all CFIs within the same EH_Frame uses the same encoding. FDEs are now
iterated in a per-CFI manner.
2023-10-20 18:15:03 +08:00
aa92778363 ld: Fix remapping of FDEs with multiple CFIs 2023-10-20 18:14:27 +08:00
e1487ed335 cargo: Update dependencies 2023-10-20 18:11:45 +08:00
73500c9081 core: Remove lazy_static from dependencies 2023-10-16 15:55:10 +08:00
9ca34c714e flake: Enable thread-safe mode for LLVM
This is required as we use the LLVM APIs from multiple threads.
2023-10-16 15:55:10 +08:00
7fc2a30c14 Force single-threaded compilation if LLVM is not thread-safe 2023-10-16 15:55:10 +08:00
950f431483 standalone: Update help text for --emit-llvm 2023-10-16 15:52:51 +08:00
a50c690428 standalone: Fix run_demo script
- Link main and module*.bc together if using multiple threads
- Fix temporary files not being deleted
2023-10-16 15:52:48 +08:00
48eb64403f standalone: Treat -T0 as using all available threads 2023-10-13 14:57:16 +08:00
2c44b58bb8 standalone: Require use of -T for specifying thread count 2023-10-13 14:36:34 +08:00
50230e61f3 core: Simplify loop condition check for list comprehension 2023-10-06 12:24:03 +08:00
0205161e35 core: Simplify list creation for comprehension 2023-10-06 12:22:38 +08:00
a2fce49b26 core: Allocate exceptions at the beginning of function
Only one instance of exception is necessary, as exceptions will always
be initialized before being thrown.
2023-10-06 12:13:20 +08:00
60a503a791 core: Allocate more stack variables at the beginning of function
All allocas for temporary objects are now placed in the beginning of the
function. Allocas for on-temporary objects are not modified because
these variables may appear in a loop and thus must be uniquely
allocated by different allocas.
2023-10-06 11:42:47 +08:00
0c49b30a90 core: Restore debug info before function call is invoked
Previously, the IR which sets up the call to the target function will
have its debug location pointing at the last argument of the function
call instead of the function call itself.
2023-10-06 11:35:23 +08:00
c7de22287e core: Fix restoration of stack address
All allocas for temporary objects are now placed in the beginning of the
function. Allocas for on-temporary objects are not modified because
these variables may appear in a loop and thus must be uniquely
represented.
2023-10-06 11:34:23 +08:00
1a54aaa1c0 core: Restore debug location when generating allocas
Debug location is lost when moving the builder cursor.
2023-10-06 11:11:50 +08:00
c5629d4eb5 standalone: Remove redundant const in demo library 2023-10-06 10:32:58 +08:00
a79286113e standalone: Add output_bool in demo library 2023-10-06 10:19:22 +08:00
901e921e00 windows: fix build 2023-10-05 18:02:53 +08:00
45a323e969 windows: update msys2 packages 2023-10-05 17:52:29 +08:00
11759a722f flake: fix pgo build 2023-10-05 17:38:36 +08:00
480a4bc0ad core: Implement comparison operators for unsigned types 2023-10-05 17:13:10 +08:00
a1d3093196 flake: update dependencies 2023-10-05 17:05:57 +08:00
85c5f2c044 cargo: update dependencies 2023-10-05 17:03:35 +08:00
f34c6053d6 standalone: Add flags to control demo output options 2023-10-04 18:11:44 +08:00
e8a5f0dfef standalone: Fix parsing NAC3 args in check_demo.sh 2023-10-04 18:03:28 +08:00
7140901261 standalone: Fix missing libraries when linking
Fixes `undefined reference to 'pow'` for pow.py using -O0.
2023-10-04 18:03:28 +08:00
2a775d822e core: Demote dead code into a stdout warning 2023-10-04 18:03:25 +08:00
1659c3e724 standalone: Remove temporary logfiles after execution 2023-09-30 09:31:18 +08:00
f53cb804ec standalone: Add execution of test cases via lli 2023-09-30 09:31:18 +08:00
279376a373 standalone: Emit IRRT IR 2023-09-30 09:31:18 +08:00
b6afd1bfda standalone: Split check_demos into check_demo
Allows individual tests to be executed.
2023-09-30 09:31:18 +08:00
be3e8f50a2 standalone: Refactor demo library to C
Needed for use by lli.
2023-09-30 09:31:18 +08:00
059d3da58b standalone: Add float64 output tests 2023-09-30 09:31:18 +08:00
9b28f23d8c flake: Add clang alongside clang-unwrapped 2023-09-30 09:31:18 +08:00
119f4d63e9 cargo: update dependencies 2023-09-29 14:46:22 +08:00
458fa12788 flake: update dependencies 2023-09-29 14:07:47 +08:00
48c6498d1f core: Fix restoration of loop target in try statement
old_loop_target is only assigned if ctx.loop_target is overwritten,
meaning that if ctx.loop_target is never overwritten, ctx.loop_target
will always be overwritten to None.

We fix this by only restoring from old_loop_target if we previously
assigned to old_loop_target.
2023-09-28 20:00:02 +08:00
2a38d5160e meta: Respect opt flags when performing whole-module optimization 2023-09-28 19:58:54 +08:00
b39831b388 standalone: Update demos
- Add `output_str` for printing a string
- Add demo_test.py to test interop
2023-09-28 19:58:53 +08:00
cb39f61e79 core: Fix passing structure arguments to extern functions
All parameters with a structure type in extern functions are marked as
`byref` instead of `byval`, as most ABIs require the first several
arguments to be passed in registers before spilling into the stack.

`byval` breaks this contract by explicitly requiring all arguments to be
 passed in the stack, breaking interop with libraries written in other
 languages.
2023-09-28 15:02:35 +08:00
176f250bdb core: Fix missing conversion to i1 for IfExp 2023-09-28 10:06:40 +08:00
acdb1de6fe meta: Improve documentation for various modified classes 2023-09-25 15:42:07 +08:00
31dcd2dde9 core: Use i8 for boolean variable allocation
In LLVM, i1 represents a 1-byte integer with a single valid bit; The
rest of the 7 upper bits are undefined. This causes problems when
using these variables in memory operations (e.g. memcpy/memmove as
needed by List slicing and assignment).

We fix this by treating all local boolean variables as i8 so that they
are well-defined for memory operations. Function ABIs will continue to
use i1, as memory operations cannot be directly performed on function
arguments or return types, instead they are always converted back into
local boolean variables (which are i8s anyways).

Fixes #315.
2023-09-25 15:42:07 +08:00
fc93fc2f0e core: Move bitcode verification error message into panic message 2023-09-22 17:16:29 +08:00
dd42022633 core: Minor refactor allocate_list 2023-09-22 17:16:29 +08:00
6dfc43c8b0 core: Add name to build_gep_and_load 2023-09-22 17:16:29 +08:00
ab2360d7a0 core: Remove emit_llvm from CodeGenLLVMOptions
We instead output an LLVM bitcode file when the option is specified on
the command-line.
2023-09-22 17:16:29 +08:00
ee1ee4ab3b core: Replace deprecated _ExtInt with _BitInt 2023-09-22 17:16:29 +08:00
3e430b9b40 core: Fix missing changes for codegen tests
Apparently the changes were dropped after rebasing.
2023-09-22 17:16:21 +08:00
9e57498958 meta: Update dependencies 2023-09-21 09:38:38 +08:00
769fd01df8 meta: Allow specifying compiler arguments for check_demos 2023-09-18 11:35:20 +08:00
411837cacd artiq: Specify target CPU when creating LLVM target options
We can try to optimize for the host and Cortex-A9 chips; The RISC-V
ISAs do not target specific chips, so we will fallback to using the
generic CPU.
2023-09-18 11:35:20 +08:00
f59d45805f standalone: Add command line flags for target properties
For testing codegen for different platforms on the host system.
2023-09-18 11:35:20 +08:00
048fcb0a69 core: Switch to LLVM New Pass Manager 2023-09-18 11:35:15 +08:00
676d07657a core: Add target field to CodeGenLLVMOptions
For specifying the target machine options when optimizing and linking.

This field is currently unused but will be required in a future
commit.
2023-09-18 09:46:24 +08:00
2482a1ef9b core: Add CodeGenTargetMachineOptions
Needed in a future commit.
2023-09-18 09:41:49 +08:00
eb63f2ad48 meta: Update to Rust Edition 2021 2023-09-15 10:25:50 +08:00
ff27e22ee6 flake: switch back to nixpkgs unstable
Too many issues with python-updates branch for now.
2023-09-13 19:15:47 +08:00
d672ef094b msys2: update packages, Python 3.11 2023-09-13 09:50:33 +08:00
d25921230e switch to Python 3.11 2023-09-13 09:44:08 +08:00
66f07b5bf4 flake: switch to nixos-unstable 2023-09-12 18:14:39 +08:00
008d50995c meta: Update run_demo.sh
- Allow more than one argument to nac3standalone executable
2023-09-12 16:20:50 +08:00
474f9050ce standalone: Expose flags in command-line 2023-09-12 16:20:49 +08:00
3993a5cf3f core: Add LLVM options to WorkerRegistry 2023-09-12 10:57:05 +08:00
39724de598 core: Add CodeGenLLVMOptions
For specifying LLVM options during code generation.
2023-09-12 10:57:04 +08:00
e4940247f3 standalone: Implement command-line parser using clap
In preparation for adding more command-line options.
2023-09-12 10:08:34 +08:00
4481d48709 core: Use C-style for loop logic for iterables
Index increment is now performed at the end of the loop body.
2023-09-06 20:09:38 +08:00
b4983526bd core: Remove redundant for.cond BB for iterable loops
Simplifies logic for creating basic blocks.
2023-09-06 20:09:37 +08:00
b4a9616648 core: Add assertion for when range has step of 0
Aligns with the behavior in Python.
2023-09-06 20:09:36 +08:00
e0de82993f core: Preserve value of variable shadowed by for loop
Previously, the final value of the target expression would be one after
the last element of the loop, which does not match Python's behavior.

This commit fixes this problem while also preserving the last assigned
value of the loop beyond the loop, matching Python's behavior.
2023-09-06 20:09:36 +08:00
6805253515 core: Use AST var name for IR name
Aids debugging IR.
2023-09-06 20:09:36 +08:00
19915bac79 core: Prepend statement type to basic block label names
Aids debugging IR.
2023-09-06 20:09:36 +08:00
17b4686260 standalone: Adapt loop example to output loop variable 2023-09-06 18:56:45 +08:00
6de0884dc1 core: Use anonymous name for variables if unspecified
The current default prefix is only derived from the instruction type,
which is not helpful during the comprehension of the IR. Changing to
anonymous names (e.g. %1) helps understand that the variable is only
needed as part of a larger (possibly named) expression.
2023-09-06 14:02:15 +08:00
f1b0e05b3d core: Rename IR variables
Because it is unclear which variables are expressions and
subexpressions, all variables which are previously anonymous are named
using (1) the control flow statement if available, (2) the possible name
of the variable as inferred from the variable name in Rust, and (3) the
"addr" prefix to indicate that the values are pointers. These three
strings are joint together using '.', forming "for.i.addr" for instance.
2023-09-06 14:02:15 +08:00
ff23968544 core: Add name parameter to gen_{var_alloc,store_target}
This allows variables in the IR to be assigned a custom name as opposed
to names with a default prefix.
2023-09-06 11:00:02 +08:00
049908044a flake: update dependencies 2023-09-04 11:00:15 +08:00
d37287a33d Cargo: Update dependencies 2023-09-04 10:43:57 +08:00
283bd7c69a cargo: update dependencies 2023-07-14 10:57:21 +08:00
3d73f5c129 flake: update dependencies 2023-07-10 13:46:00 +08:00
d824c5d8b5 flake: cleanup dev shells 2023-05-30 16:28:46 +08:00
b8d637f5c4 cargo: update dependencies 2023-05-27 18:56:21 +08:00
3af287d1c4 flake: nixpkgs 23.05 2023-05-27 18:14:55 +08:00
5b53be0311 update dependencies 2023-04-30 17:11:47 +08:00
aead36f0fd update dependencies 2023-03-08 15:19:09 +08:00
c269444c0b msys2: update packages 2023-01-14 16:09:21 +08:00
52cec3c12f msys2: nix store doesn't like tildes 2023-01-14 16:09:00 +08:00
2927f2a1d0 msys2: adapt to recent pacman 2023-01-14 16:08:39 +08:00
c1c45373a6 update dependencies 2023-01-12 19:31:03 +08:00
946ea155b8 flake: switch to NixOS release 2022-11-30 11:37:48 +08:00
085c6ee738 update dependencies 2022-11-18 16:15:46 +08:00
cfa67c418a update MSYS2 URL 2022-11-03 19:00:44 +08:00
813bfa92a7 windows: fix nac3artiq module installation path 2022-08-05 22:42:32 +08:00
fff4b65169 windows: parallel LLVM build 2022-08-05 18:24:00 +08:00
c891fffd75 windows: update msys2, python 3.10 2022-08-05 17:27:07 +08:00
12acd35e15 switch to nixpkgs master, python 3.10 2022-08-05 17:24:49 +08:00
f66ca02b2d update Rust dependencies 2022-08-05 16:58:57 +08:00
b514f91441 nac3artiq: inherit kernel constructors
Closes #139
Co-authored-by: z78078 <cc@m-labs.hk>
Co-committed-by: z78078 <cc@m-labs.hk>
2022-07-28 19:18:36 +08:00
8f95b79257 nac3artiq: throw error message when constructor use rpc decorator (#306)
Co-authored-by: z78078 <cc@m-labs.hk>
Co-committed-by: z78078 <cc@m-labs.hk>
2022-07-11 15:55:55 +08:00
ebd25af38b nac3standalone: allow classes without explicit init (#221)
Reviewed-on: M-Labs/nac3#304
Co-authored-by: z78078 <cc@m-labs.hk>
Co-committed-by: z78078 <cc@m-labs.hk>
2022-07-07 10:36:25 +08:00
96b3a3bf5c work around random segmentation fault (#275)
Reviewed-on: M-Labs/nac3#302
Co-authored-by: z78078 <cc@m-labs.hk>
Co-committed-by: z78078 <cc@m-labs.hk>
2022-07-04 18:06:36 +08:00
a18d095245 nac3core: codegen fix call parameter type error 2022-07-04 14:39:33 +08:00
b242463548 update dependencies 2022-07-02 19:04:19 +08:00
8e6e4d6715 README: call for Nix 2.8 (older versions have flake bugs) 2022-06-06 18:14:21 +08:00
73c2aefe4b README: mention nac3ld 2022-06-06 18:13:21 +08:00
892597cda4 update dependencies 2022-06-06 17:54:23 +08:00
33321c5e9c README,nix: remove lld 2022-06-06 17:50:32 +08:00
50ed04b787 nac3ld: replace unsafe code 2022-06-06 14:41:14 +08:00
7cb9be0f81 nac3artiq: refactor compile methods
Avoids writing relocatable object to a file when linking with nac3ld.
2022-06-06 14:41:10 +08:00
ac560ba985 nac3artiq: switch ld.lld to nac3ld for non-host target 2022-06-06 14:40:13 +08:00
a96371145d add nac3ld 2022-06-06 14:40:13 +08:00
8addf2b55e nac3standalone: add more tests 2022-06-01 17:58:16 +08:00
5d5e9a5e02 nac3core: fix codegen error of inheritance 2022-06-01 17:58:16 +08:00
4c39dd240f update all dependencies 2022-05-31 23:09:51 +08:00
48fc5ceb8e nac3artiq: demote global value to private
... except typeinfo & now symbols.
typeinfo will be read by the runtime linker; now is for now-pinning.
2022-05-30 22:46:41 +08:00
c4ab2855e5 nac3core: pretty print codegen panic error 2022-05-30 04:09:21 +08:00
ffac37dc48 nac3core: fix exception type in primitive store 2022-05-29 19:14:00 +08:00
76473152e8 nac3core: fix llvm.expect intrinsic name
this might be one of the causes for the random segfault bug
2022-05-27 04:23:49 +08:00
b04631e935 update dependencies, switch to nixpkgs 22.05 2022-05-24 11:10:29 +08:00
09820e5aed nac3artiq: return err instead of panic for host object attribute error 2022-05-18 02:54:42 +08:00
0ec2ed4d91 update dependencies 2022-05-17 12:05:12 +08:00
2cb725b7ac nac3artiq: correct global name for const object 2022-05-16 02:50:42 +08:00
b9259b1907 update nixpkgs and LLVM 2022-05-14 16:33:03 +08:00
096f4b03c0 nac3core: fix assignment 2022-05-14 02:30:08 +08:00
a022005183 nac3core: fix broken tests 2022-05-11 03:53:53 +08:00
325ba0a408 nac3core: add debug info 2022-05-11 03:53:53 +08:00
ae6434696c nac3artiq: rename the filename of modinit
rename from __nac3_synthesized_modinit__ to <nac3_synthesized_modinit> to be more idomatic python
2022-05-11 03:52:16 +08:00
3f327113b2 update dependencies, use upstream inkwell 2022-04-27 15:41:46 +08:00
27d509d70e nac3artiq: get_const_obj should no longer make a pointer. Closes #272 2022-04-27 15:28:58 +08:00
a321b13bec fix typos 2022-04-27 11:08:10 +08:00
48cb485b89 nac3core: show outer type info in type error messages
Reviewed-on: M-Labs/nac3#274
Co-authored-by: ychenfo <yc@m-labs.hk>
Co-committed-by: ychenfo <yc@m-labs.hk>
2022-04-22 15:31:55 +08:00
837aaa95f1 flake: contain sipyco to nac3artiq-profile 2022-04-19 10:34:55 +08:00
a19e9c0bec flake: provide llvm-as for IRRT
clang already depends on llvmPackages_13.llvm, so, unlike the statically-linked tools
from llvm-nac3, this does not make the bloat even worse.
2022-04-19 10:23:41 +08:00
5dbe1d3d7d llvm: restore llvm-config 2022-04-19 10:23:12 +08:00
e9bca3c822 llvm: set LLVM_BUILD_TOOLS=OFF 2022-04-19 00:30:11 +08:00
42d1aad507 flake: add PGO build to Hydra 2022-04-18 23:58:43 +08:00
2777a6e05f flake: use nac3devices example for PGO 2022-04-18 23:57:57 +08:00
05be5e93c4 flake: update nixpkgs 2022-04-18 18:48:05 +08:00
85f21060e4 update to LLVM 14 2022-04-18 18:47:20 +08:00
a308d24caa nac3standalone: cleanup 2022-04-18 16:02:48 +08:00
1eac111d4c cleanup 2022-04-18 15:55:37 +08:00
44199781dc nac3standalone: add tests for operators 2022-04-18 15:31:56 +08:00
711c3d3303 nac3core: support custom operators 2022-04-18 15:31:56 +08:00
0975264482 README: center icon 2022-04-18 15:11:32 +08:00
087aded3a3 add icon
Icon is copyright Evgeny Filatov and not covered by any free software license.
2022-04-18 15:07:53 +08:00
f14b32be67 nac3artiq: type check host int bound instead of panic when codegen 2022-04-16 03:01:37 +08:00
David Nadlinger
879c66cccf flake.nix: Fix outdated nixConfig keys
The old syntax seems to be silently ignored on (at least)
Nix 2.7.0.
2022-04-13 21:21:18 +01:00
wylited
35b6459c58 nac3core: replace paramter with parameter 2022-04-13 15:42:26 +08:00
e94b25f544 spelling (#264)
Co-authored-by: wylited <ds@m-labs.hk>
Co-committed-by: wylited <ds@m-labs.hk>
2022-04-13 11:32:31 +08:00
6972689469 nac3artiq: cleanup demo 2022-04-12 10:34:14 +08:00
3fb22c9182 nac3artiq: treat host numpy.float64 as float. Closes #90 2022-04-12 10:33:28 +08:00
1e7abf0268 fix tests 2022-04-12 10:06:41 +08:00
f5a6d29106 update insta snapshots 2022-04-12 09:56:49 +08:00
ca07cb66cd format typevars consistently 2022-04-12 09:28:17 +08:00
93e9a6a38a update dependencies 2022-04-12 09:13:04 +08:00
722e3df086 nac3core, artiq: optimize kernel invariant for tuple index 2022-04-11 14:58:40 +08:00
ad9ad22cb8 nac3core: optimize unwrap KernelInvariant 2022-04-11 14:58:35 +08:00
f66f66b3a4 nac3core, artiq: remove unnecessary ptr casts 2022-04-10 01:28:46 +08:00
6c485bc9dc nac3artiq: skip attribute writeback for option
option types do not have any fields to be written back to the host so it is ok to skip. If we do not skip, there will be error when getting the value of it since it can be `none`, whose type is not concrete
2022-04-10 01:28:30 +08:00
089bba96a3 nac3artiq: get_obj_value take an additional argument for expected type 2022-04-10 01:28:30 +08:00
0e0871bc38 nac3core, artiq: to_basic_value_enum takes an argument indicating the expected type 2022-04-10 01:28:22 +08:00
26187bff0b nac3core: add missing bound check and negative index handling for list subscription assignment 2022-04-09 04:56:31 +08:00
86ce513cb5 nac3standalone: fix broken test
previously this test unexpectedly passed because it is a slice assignment to extend the list, which is valid in CPython and hence in interpret_demo, and which also happened to give the same output in nac3 by memmove the elements in the list of bool
2022-04-05 18:21:46 +08:00
c29cbf6ddd nac3core: add bound check for list slice 2022-04-05 18:21:46 +08:00
7443c5ea0f nac3core: add location information to codegen context 2022-04-05 18:21:46 +08:00
f55b077e60 README: update Windows instructions 2022-04-05 18:07:38 +08:00
e05b0bf5dc flake: update nixpkgs 2022-04-05 10:10:08 +08:00
8eda0affc9 windows: add wine-msys2-build 2022-04-05 10:06:36 +08:00
75c53b40a3 windows: update msys2 packages, add setuptools to environment 2022-04-05 10:06:14 +08:00
0d10044d66 Merge pull request 'Fix float**int with negative power' (#254) from neg_powi_fix into master
Reviewed-on: M-Labs/nac3#254
2022-04-04 22:43:20 +08:00
23b7f4ef18 nac3standalone: add tests for power 2022-04-04 22:10:56 +08:00
710904f975 nac3core: fix powi with negative integer power 2022-04-04 22:10:56 +08:00
4bf452ec5a windows: do not check dependencies when making package 2022-04-04 16:03:59 +08:00
9fdce11efe windows: depend on python 2022-04-04 15:21:34 +08:00
f24ef85aed hydra: use msys2 type 2022-04-04 15:03:53 +08:00
4a19787f10 README: update 2022-04-04 15:03:44 +08:00
8209c0a475 windows: create MSYS2 package 2022-04-04 14:24:47 +08:00
144 changed files with 38302 additions and 11068 deletions

32
.clang-format Normal file
View File

@ -0,0 +1,32 @@
BasedOnStyle: LLVM
Language: Cpp
Standard: Cpp11
AccessModifierOffset: -1
AlignEscapedNewlines: Left
AlwaysBreakAfterReturnType: None
AlwaysBreakTemplateDeclarations: Yes
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortFunctionsOnASingleLine: Inline
BinPackParameters: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: AfterColon
BreakInheritanceList: AfterColon
ColumnLimit: 120
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ContinuationIndentWidth: 4
DerivePointerAlignment: false
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4
MaxEmptyLinesToKeep: 1
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SortUsingDeclarations: true
SpaceAfterTemplateKeyword: false
SpacesBeforeTrailingComments: 2
TabWidth: 4
UseTab: Never

1
.clippy.toml Normal file
View File

@ -0,0 +1 @@
doc-valid-idents = ["CPython", "NumPy", ".."]

3
.gitignore vendored
View File

@ -1,3 +1,4 @@
__pycache__ __pycache__
/target /target
windows/msys2 /nac3standalone/demo/linalg/target
nix/windows/msys2

24
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,24 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_stages: [commit]
repos:
- repo: local
hooks:
- id: nac3-cargo-fmt
name: nac3 cargo format
entry: nix
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo fmt on the codebase.
args: [develop, -c, cargo, fmt, --all]
- id: nac3-cargo-clippy
name: nac3 cargo clippy
entry: nix
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo clippy on the codebase.
args: [develop, -c, cargo, clippy, --tests]

1214
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,6 @@
[workspace] [workspace]
members = [ members = [
"nac3ld",
"nac3ast", "nac3ast",
"nac3parser", "nac3parser",
"nac3core", "nac3core",
@ -7,6 +8,7 @@ members = [
"nac3artiq", "nac3artiq",
"runkernel", "runkernel",
] ]
resolver = "2"
[profile.release] [profile.release]
debug = true debug = true

View File

@ -1,5 +1,10 @@
# NAC3 <div align="center">
![icon](https://git.m-labs.hk/M-Labs/nac3/raw/branch/master/nac3.svg)
</div>
# NAC3
NAC3 is a major, backward-incompatible rewrite of the compiler for the [ARTIQ](https://m-labs.hk/artiq) physics experiment control and data acquisition system. It features greatly improved compilation speeds, a much better type system, and more predictable and transparent operation. NAC3 is a major, backward-incompatible rewrite of the compiler for the [ARTIQ](https://m-labs.hk/artiq) physics experiment control and data acquisition system. It features greatly improved compilation speeds, a much better type system, and more predictable and transparent operation.
NAC3 has a modular design and its applicability reaches beyond ARTIQ. The ``nac3core`` module does not contain anything specific to ARTIQ, and can be used in any project that requires compiling Python to machine code. NAC3 has a modular design and its applicability reaches beyond ARTIQ. The ``nac3core`` module does not contain anything specific to ARTIQ, and can be used in any project that requires compiling Python to machine code.
@ -8,7 +13,7 @@ NAC3 has a modular design and its applicability reaches beyond ARTIQ. The ``nac3
## Packaging ## Packaging
NAC3 is packaged using the [Nix](https://nixos.org) Flakes system. Install Nix 2.4+ and enable flakes by adding ``experimental-features = nix-command flakes`` to ``nix.conf`` (e.g. ``~/.config/nix/nix.conf``). NAC3 is packaged using the [Nix](https://nixos.org) Flakes system. Install Nix 2.8+ and enable flakes by adding ``experimental-features = nix-command flakes`` to ``nix.conf`` (e.g. ``~/.config/nix/nix.conf``).
## Try NAC3 ## Try NAC3
@ -16,44 +21,21 @@ NAC3 is packaged using the [Nix](https://nixos.org) Flakes system. Install Nix 2
After setting up Nix as above, use ``nix shell git+https://github.com/m-labs/artiq.git?ref=nac3`` to get a shell with the NAC3 version of ARTIQ. See the ``examples`` directory in ARTIQ (``nac3`` Git branch) for some samples of NAC3 kernel code. After setting up Nix as above, use ``nix shell git+https://github.com/m-labs/artiq.git?ref=nac3`` to get a shell with the NAC3 version of ARTIQ. See the ``examples`` directory in ARTIQ (``nac3`` Git branch) for some samples of NAC3 kernel code.
### Windows (work in progress) ### Windows
NAC3 ARTIQ packaging for MSYS2/Windows is not yet complete so installation involves many manual steps. It is also less tested and you may encounter problems. Install [MSYS2](https://www.msys2.org/), and open "MSYS2 CLANG64". Edit ``/etc/pacman.conf`` to add:
Install [MSYS2](https://www.msys2.org/) and run the following commands:
``` ```
pacman -S mingw-w64-x86_64-python-h5py mingw-w64-x86_64-python-pyqt5 mingw-w64-x86_64-python-scipy mingw-w64-x86_64-python-prettytable mingw-w64-x86_64-python-pygit2 [artiq]
pacman -S mingw-w64-x86_64-python-pip SigLevel = Optional TrustAll
pip install qasync Server = https://msys2.m-labs.hk/artiq-nac3
pip install pyqtgraph
pacman -S patch git
git clone https://github.com/m-labs/sipyco
cd sipyco
git show 20c946aad78872fe60b78d9b57a624d69f3eea47 | patch -p1 -R
python setup.py install
cd ..
git clone -b nac3 https://github.com/m-labs/artiq
cd artiq
python setup.py install
``` ```
Locate a recent build of ``nac3artiq-msys2`` from [Hydra](https://nixbld.m-labs.hk) and download ``nac3artiq.zip``. Then extract the contents in the appropriate location: Then run the following commands:
``` ```
pacman -S unzip pacman -Syu
wget https://nixbld.m-labs.hk/build/115529/download/1/nac3artiq.zip # edit the build number pacman -S mingw-w64-clang-x86_64-artiq
unzip nac3artiq.zip -d C:/msys64/mingw64/lib/python3.9/site-packages
``` ```
Do the same for ``lld-msys2``:
```
wget https://nixbld.m-labs.hk/build/115527/download/1/ld.lld.exe
mv ld.lld.exe C:/msys64/mingw64/bin
```
And you should be good to go.
Note: This build of NAC3 cannot be used with Anaconda Python nor the python.org binaries for Windows. Those Python versions are compiled with Visual Studio (MSVC) and their ABI is incompatible with the GNU ABI used in this build. We have no plans to support Visual Studio nor the MSVC ABI. If you need a MSVC build, please install the requisite bloated spyware from Microsoft and compile NAC3 yourself.
## For developers ## For developers
This repository contains: This repository contains:
@ -61,6 +43,7 @@ This repository contains:
- ``nac3parser``: Python parser (based on RustPython). - ``nac3parser``: Python parser (based on RustPython).
- ``nac3core``: Core compiler library, containing type-checking and code generation. - ``nac3core``: Core compiler library, containing type-checking and code generation.
- ``nac3standalone``: Standalone compiler tool (core language only). - ``nac3standalone``: Standalone compiler tool (core language only).
- ``nac3ld``: Minimalist RISC-V and ARM linker.
- ``nac3artiq``: Integration with ARTIQ and implementation of ARTIQ-specific extensions to the core language. - ``nac3artiq``: Integration with ARTIQ and implementation of ARTIQ-specific extensions to the core language.
- ``runkernel``: Simple program that runs compiled ARTIQ kernels on the host and displays RTIO operations. Useful for testing without hardware. - ``runkernel``: Simple program that runs compiled ARTIQ kernels on the host and displays RTIO operations. Useful for testing without hardware.
@ -68,3 +51,12 @@ Use ``nix develop`` in this repository to enter a development shell.
If you are using a different shell than bash you can use e.g. ``nix develop --command fish``. If you are using a different shell than bash you can use e.g. ``nix develop --command fish``.
Build NAC3 with ``cargo build --release``. See the demonstrations in ``nac3artiq`` and ``nac3standalone``. Build NAC3 with ``cargo build --release``. See the demonstrations in ``nac3artiq`` and ``nac3standalone``.
### Pre-Commit Hooks
You are strongly recommended to use the provided pre-commit hooks to automatically reformat files and check for non-optimal Rust practices using Clippy. Run `pre-commit install` to install the hook and `pre-commit` will automatically run `cargo fmt` and `cargo clippy` for you.
Several things to note:
- If `cargo fmt` or `cargo clippy` returns an error, the pre-commit hook will fail. You should fix all errors before trying to commit again.
- If `cargo fmt` reformats some files, the pre-commit hook will also fail. You should review the changes and, if satisfied, try to commit again.

8
flake.lock generated
View File

@ -2,16 +2,16 @@
"nodes": { "nodes": {
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1648553562, "lastModified": 1727348695,
"narHash": "sha256-xQhRKu6h0phd56oCzGjkhHkY4eDI1XKedGqkFtlXapk=", "narHash": "sha256-J+PeFKSDV+pHL7ukkfpVzCOO7mBSrrpJ3svwBFABbhI=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "9b168e5e62406fa2e55e132f390379a6ba22b402", "rev": "1925c603f17fc89f4c8f6bf6f631a802ad85d784",
"type": "github" "type": "github"
}, },
"original": { "original": {
"owner": "NixOS", "owner": "NixOS",
"ref": "nixos-21.11", "ref": "nixos-unstable",
"repo": "nixpkgs", "repo": "nixpkgs",
"type": "github" "type": "github"
} }

109
flake.nix
View File

@ -1,29 +1,57 @@
{ {
description = "The third-generation ARTIQ compiler"; description = "The third-generation ARTIQ compiler";
inputs.nixpkgs.url = github:NixOS/nixpkgs/nixos-21.11; inputs.nixpkgs.url = github:NixOS/nixpkgs/nixos-unstable;
outputs = { self, nixpkgs }: outputs = { self, nixpkgs }:
let let
pkgs = import nixpkgs { system = "x86_64-linux"; }; pkgs = import nixpkgs { system = "x86_64-linux"; };
pkgs32 = import nixpkgs { system = "i686-linux"; };
in rec { in rec {
packages.x86_64-linux = rec { packages.x86_64-linux = rec {
llvm-nac3 = pkgs.callPackage ./nix/llvm {}; llvm-nac3 = pkgs.callPackage ./nix/llvm {};
llvm-tools-irrt = pkgs.runCommandNoCC "llvm-tools-irrt" {}
''
mkdir -p $out/bin
ln -s ${pkgs.llvmPackages_14.clang-unwrapped}/bin/clang $out/bin/clang-irrt
ln -s ${pkgs.llvmPackages_14.llvm.out}/bin/llvm-as $out/bin/llvm-as-irrt
'';
demo-linalg-stub = pkgs.rustPlatform.buildRustPackage {
name = "demo-linalg-stub";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
demo-linalg-stub32 = pkgs32.rustPlatform.buildRustPackage {
name = "demo-linalg-stub32";
src = ./nac3standalone/demo/linalg;
cargoLock = {
lockFile = ./nac3standalone/demo/linalg/Cargo.lock;
};
doCheck = false;
};
nac3artiq = pkgs.python3Packages.toPythonModule ( nac3artiq = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage { pkgs.rustPlatform.buildRustPackage rec {
name = "nac3artiq"; name = "nac3artiq";
outputs = [ "out" "runkernel" "standalone" ]; outputs = [ "out" "runkernel" "standalone" ];
src = self; src = self;
cargoLock = { lockFile = ./Cargo.lock; }; cargoLock = {
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_13.clang-unwrapped llvm-nac3 ]; lockFile = ./Cargo.lock;
};
passthru.cargoLock = cargoLock;
nativeBuildInputs = [ pkgs.python3 (pkgs.wrapClangMulti pkgs.llvmPackages_14.clang) llvm-tools-irrt pkgs.llvmPackages_14.llvm.out llvm-nac3 ];
buildInputs = [ pkgs.python3 llvm-nac3 ]; buildInputs = [ pkgs.python3 llvm-nac3 ];
checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ])) ]; checkInputs = [ (pkgs.python3.withPackages(ps: [ ps.numpy ps.scipy ])) ];
checkPhase = checkPhase =
'' ''
echo "Checking nac3standalone demos..." echo "Checking nac3standalone demos..."
pushd nac3standalone/demo pushd nac3standalone/demo
patchShebangs . patchShebangs .
./check_demos.sh export DEMO_LINALG_STUB=${demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${demo-linalg-stub32}/lib/liblinalg.a
./check_demos.sh -i686
popd popd
echo "Running Cargo tests..." echo "Running Cargo tests..."
cargoCheckHook cargoCheckHook
@ -49,21 +77,21 @@
# LLVM PGO support # LLVM PGO support
llvm-nac3-instrumented = pkgs.callPackage ./nix/llvm { llvm-nac3-instrumented = pkgs.callPackage ./nix/llvm {
stdenv = pkgs.llvmPackages_13.stdenv; stdenv = pkgs.llvmPackages_14.stdenv;
extraCmakeFlags = [ "-DLLVM_BUILD_INSTRUMENTED=IR" ]; extraCmakeFlags = [ "-DLLVM_BUILD_INSTRUMENTED=IR" ];
}; };
nac3artiq-instrumented = pkgs.python3Packages.toPythonModule ( nac3artiq-instrumented = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage { pkgs.rustPlatform.buildRustPackage {
name = "nac3artiq-instrumented"; name = "nac3artiq-instrumented";
src = self; src = self;
cargoLock = { lockFile = ./Cargo.lock; }; inherit (nac3artiq) cargoLock;
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_13.clang-unwrapped llvm-nac3-instrumented ]; nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-instrumented ];
buildInputs = [ pkgs.python3 llvm-nac3-instrumented ]; buildInputs = [ pkgs.python3 llvm-nac3-instrumented ];
cargoBuildFlags = [ "--package" "nac3artiq" "--features" "init-llvm-profile" ]; cargoBuildFlags = [ "--package" "nac3artiq" "--features" "init-llvm-profile" ];
doCheck = false; doCheck = false;
configurePhase = configurePhase =
'' ''
export CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUSTFLAGS="-C link-arg=-L${pkgs.llvmPackages_13.compiler-rt}/lib/linux -C link-arg=-lclang_rt.profile-x86_64" export CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUSTFLAGS="-C link-arg=-L${pkgs.llvmPackages_14.compiler-rt}/lib/linux -C link-arg=-lclang_rt.profile-x86_64"
''; '';
installPhase = installPhase =
'' ''
@ -75,11 +103,35 @@
); );
nac3artiq-profile = pkgs.stdenvNoCC.mkDerivation { nac3artiq-profile = pkgs.stdenvNoCC.mkDerivation {
name = "nac3artiq-profile"; name = "nac3artiq-profile";
src = self; srcs = [
buildInputs = [ (python3-mimalloc.withPackages(ps: [ ps.numpy nac3artiq-instrumented ])) pkgs.lld_13 pkgs.llvmPackages_13.libllvm ]; (pkgs.fetchFromGitHub {
owner = "m-labs";
repo = "sipyco";
rev = "939f84f9b5eef7efbf7423c735d1834783b6140e";
sha256 = "sha256-15Nun4EY35j+6SPZkjzZtyH/ncxLS60KuGJjFh5kSTc=";
})
(pkgs.fetchFromGitHub {
owner = "m-labs";
repo = "artiq";
rev = "923ca3377d42c815f979983134ec549dc39d3ca0";
sha256 = "sha256-oJoEeNEeNFSUyh6jXG8Tzp6qHVikeHS0CzfE+mODPgw=";
})
];
buildInputs = [
(python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ps.jsonschema ps.lmdb nac3artiq-instrumented ]))
pkgs.llvmPackages_14.llvm.out
];
phases = [ "buildPhase" "installPhase" ]; phases = [ "buildPhase" "installPhase" ];
# TODO: get more representative code. buildPhase =
buildPhase = "python $src/nac3artiq/demo/demo.py"; ''
srcs=($srcs)
sipyco=''${srcs[0]}
artiq=''${srcs[1]}
export PYTHONPATH=$sipyco:$artiq
python -m artiq.frontend.artiq_ddb_template $artiq/artiq/examples/nac3devices/nac3devices.json > device_db.py
cp $artiq/artiq/examples/nac3devices/nac3devices.py .
python -m artiq.frontend.artiq_compile nac3devices.py
'';
installPhase = installPhase =
'' ''
mkdir $out mkdir $out
@ -87,15 +139,15 @@
''; '';
}; };
llvm-nac3-pgo = pkgs.callPackage ./nix/llvm { llvm-nac3-pgo = pkgs.callPackage ./nix/llvm {
stdenv = pkgs.llvmPackages_13.stdenv; stdenv = pkgs.llvmPackages_14.stdenv;
extraCmakeFlags = [ "-DLLVM_PROFDATA_FILE=${nac3artiq-profile}/llvm.profdata" ]; extraCmakeFlags = [ "-DLLVM_PROFDATA_FILE=${nac3artiq-profile}/llvm.profdata" ];
}; };
nac3artiq-pgo = pkgs.python3Packages.toPythonModule ( nac3artiq-pgo = pkgs.python3Packages.toPythonModule (
pkgs.rustPlatform.buildRustPackage { pkgs.rustPlatform.buildRustPackage {
name = "nac3artiq-pgo"; name = "nac3artiq-pgo";
src = self; src = self;
cargoLock = { lockFile = ./Cargo.lock; }; inherit (nac3artiq) cargoLock;
nativeBuildInputs = [ pkgs.python3 pkgs.llvmPackages_13.clang-unwrapped llvm-nac3-pgo ]; nativeBuildInputs = [ pkgs.python3 packages.x86_64-linux.llvm-tools-irrt llvm-nac3-pgo ];
buildInputs = [ pkgs.python3 llvm-nac3-pgo ]; buildInputs = [ pkgs.python3 llvm-nac3-pgo ];
cargoBuildFlags = [ "--package" "nac3artiq" ]; cargoBuildFlags = [ "--package" "nac3artiq" ];
cargoTestFlags = [ "--package" "nac3ast" "--package" "nac3parser" "--package" "nac3core" "--package" "nac3artiq" ]; cargoTestFlags = [ "--package" "nac3ast" "--package" "nac3parser" "--package" "nac3core" "--package" "nac3artiq" ];
@ -111,22 +163,29 @@
packages.x86_64-w64-mingw32 = import ./nix/windows { inherit pkgs; }; packages.x86_64-w64-mingw32 = import ./nix/windows { inherit pkgs; };
devShell.x86_64-linux = pkgs.mkShell { devShells.x86_64-linux.default = pkgs.mkShell {
name = "nac3-dev-shell"; name = "nac3-dev-shell";
buildInputs = with pkgs; [ buildInputs = with pkgs; [
# build dependencies # build dependencies
packages.x86_64-linux.llvm-nac3 packages.x86_64-linux.llvm-nac3
llvmPackages_13.clang-unwrapped # IRRT (pkgs.wrapClangMulti llvmPackages_14.clang) llvmPackages_14.llvm.out # for running nac3standalone demos
packages.x86_64-linux.llvm-tools-irrt
cargo cargo
rustc rustc
# runtime dependencies # runtime dependencies
lld_13 lld_14 # for running kernels on the host
(packages.x86_64-linux.python3-mimalloc.withPackages(ps: [ ps.numpy ])) (packages.x86_64-linux.python3-mimalloc.withPackages(ps: [ ps.numpy ps.scipy ]))
# development tools # development tools
cargo-insta cargo-insta
clippy clippy
pre-commit
rustfmt rustfmt
]; ];
shellHook =
''
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
export DEMO_LINALG_STUB32=${packages.x86_64-linux.demo-linalg-stub32}/lib/liblinalg.a
'';
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";
@ -139,15 +198,15 @@
}; };
hydraJobs = { hydraJobs = {
inherit (packages.x86_64-linux) llvm-nac3 nac3artiq; inherit (packages.x86_64-linux) llvm-nac3 nac3artiq nac3artiq-pgo;
llvm-nac3-msys2 = packages.x86_64-w64-mingw32.llvm-nac3; llvm-nac3-msys2 = packages.x86_64-w64-mingw32.llvm-nac3;
nac3artiq-msys2 = packages.x86_64-w64-mingw32.nac3artiq; nac3artiq-msys2 = packages.x86_64-w64-mingw32.nac3artiq;
lld-msys2 = packages.x86_64-w64-mingw32.lld; nac3artiq-msys2-pkg = packages.x86_64-w64-mingw32.nac3artiq-pkg;
}; };
}; };
nixConfig = { nixConfig = {
binaryCachePublicKeys = ["nixbld.m-labs.hk-1:5aSRVA5b320xbNvu30tqxVPXpld73bhtOeH6uAjRyHc="]; extra-trusted-public-keys = "nixbld.m-labs.hk-1:5aSRVA5b320xbNvu30tqxVPXpld73bhtOeH6uAjRyHc=";
binaryCaches = ["https://nixbld.m-labs.hk" "https://cache.nixos.org"]; extra-substituters = "https://nixbld.m-labs.hk";
}; };
} }

56
nac3.svg Normal file
View File

@ -0,0 +1,56 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
id="a"
width="128"
height="128"
viewBox="0 0 95.99999 95.99999"
version="1.1"
sodipodi:docname="nac3.svg"
inkscape:version="1.1.1 (3bf5ae0d25, 2021-09-20)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<defs
id="defs11" />
<sodipodi:namedview
id="namedview9"
pagecolor="#ffffff"
bordercolor="#666666"
borderopacity="1.0"
inkscape:pageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:document-units="mm"
showgrid="false"
units="px"
width="128px"
inkscape:zoom="5.9448568"
inkscape:cx="60.472441"
inkscape:cy="60.556547"
inkscape:window-width="2560"
inkscape:window-height="1371"
inkscape:window-x="0"
inkscape:window-y="32"
inkscape:window-maximized="1"
inkscape:current-layer="a" />
<rect
x="40.072601"
y="-26.776209"
width="55.668747"
height="55.668747"
transform="matrix(0.71803815,0.69600374,-0.71803815,0.69600374,0,0)"
style="fill:#be211e;stroke:#000000;stroke-width:4.37375px;stroke-linecap:round;stroke-linejoin:round"
id="rect2" />
<line
x1="38.00692"
y1="63.457153"
x2="57.993061"
y2="63.457153"
style="fill:none;stroke:#000000;stroke-width:4.37269px;stroke-linecap:round;stroke-linejoin:round"
id="line4" />
<path
d="m 48.007301,57.843329 c -1.943097,0 -3.877522,-0.41727 -5.686157,-1.246007 -3.218257,-1.474616 -5.650382,-4.075418 -6.849639,-7.323671 -2.065624,-5.588921 -1.192751,-10.226647 2.575258,-13.827 0.611554,-0.584909 1.518048,-0.773041 2.323689,-0.488206 0.80673,0.286405 1.369495,0.998486 1.447563,1.827234 0.237469,2.549302 2.439719,5.917376 4.28414,6.55273 0.396859,0.13506 0.820953,-0.05859 1.097084,-0.35222 0.339254,-0.360754 0.451065,-0.961893 -1.013597,-3.191372 -2.089851,-3.181137 -4.638728,-8.754903 -0.262407,-15.069853 0.494457,-0.713491 1.384673,-1.068907 2.256469,-0.909156 0.871795,0.161332 1.583757,0.806404 1.752251,1.651189 0.716448,3.591862 2.962357,6.151755 5.199306,8.023138 1.935503,1.61861 4.344688,3.867387 5.435687,7.096643 2.283183,6.758017 -1.202511,14.114988 -8.060822,16.494025 -1.467083,0.509226 -2.98513,0.762536 -4.498836,0.762536 z M 39.358865,40.002192 c -0.304711,0.696206 -0.541636,2.080524 -0.56865,2.237454 -0.330316,1.918771 0.168305,3.803963 0.846157,5.539951 0.856828,2.19436 2.437543,3.942467 4.583411,4.925713 2.143691,0.981675 4.554131,1.097816 6.789992,0.322666 4.571485,-1.586549 6.977584,-6.532238 5.363036,-11.02597 v -5.27e-4 C 55.455481,39.447968 54.023463,38.162043 52.221335,36.65432 50.876945,35.529534 49.409662,33.987726 48.417983,32.135555 48.01343,31.37996 47.79547,30.34303 47.76669,29.413263 c -0.187481,0.669514 -0.212441,2.325923 -0.150396,2.93691 0.179209,1.764456 1.333476,3.644546 2.340611,5.171243 1.311568,1.988179 2.72058,6.037272 0.459681,8.367985 -1.54192,1.58953 -4.038511,2.052034 -5.839973,1.38492 -2.398314,-0.888147 -3.942744,-2.690627 -4.941118,-4.768029 -0.121194,-0.25217 -0.532464,-1.174187 -0.276619,-2.5041 z"
id="path6"
style="stroke-width:1.09317" />
</svg>

After

Width:  |  Height:  |  Size: 3.3 KiB

View File

@ -2,23 +2,20 @@
name = "nac3artiq" name = "nac3artiq"
version = "0.1.0" version = "0.1.0"
authors = ["M-Labs"] authors = ["M-Labs"]
edition = "2018" edition = "2021"
[lib] [lib]
name = "nac3artiq" name = "nac3artiq"
crate-type = ["cdylib"] crate-type = ["cdylib"]
[dependencies] [dependencies]
pyo3 = { version = "0.14", features = ["extension-module"] } itertools = "0.13"
parking_lot = "0.11" pyo3 = { version = "0.21", features = ["extension-module", "gil-refs"] }
tempfile = "3" parking_lot = "0.12"
nac3parser = { path = "../nac3parser" } tempfile = "3.13"
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
nac3ld = { path = "../nac3ld" }
[dependencies.inkwell]
version = "0.1.0-beta.4"
default-features = false
features = ["llvm13-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[features] [features]
init-llvm-profile = [] init-llvm-profile = []
no-escape-analysis = ["nac3core/no-escape-analysis"]

View File

@ -1,10 +1,4 @@
from min_artiq import * from min_artiq import *
from numpy import int32, int64
@extern
def output_int(x: int32):
...
@nac3 @nac3

View File

@ -18,6 +18,13 @@ class EmbeddingMap:
"SPIError", "SPIError",
"0:ZeroDivisionError", "0:ZeroDivisionError",
"0:IndexError", "0:IndexError",
"0:ValueError",
"0:RuntimeError",
"0:AssertionError",
"0:KeyError",
"0:NotImplementedError",
"0:OverflowError",
"0:IOError",
"0:UnwrapNoneError"]) "0:UnwrapNoneError"])
def preallocate_runtime_exception_names(self, names): def preallocate_runtime_exception_names(self, names):

View File

@ -0,0 +1,24 @@
from min_artiq import *
from numpy import int32
@nac3
class EmptyList:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@rpc
def get_empty(self) -> list[int32]:
return []
@kernel
def run(self):
a: list[int32] = self.get_empty()
if a != []:
raise ValueError
if __name__ == "__main__":
EmptyList().run()

View File

@ -10,7 +10,7 @@ from embedding_map import EmbeddingMap
__all__ = [ __all__ = [
"Kernel", "KernelInvariant", "virtual", "Kernel", "KernelInvariant", "virtual", "ConstGeneric",
"Option", "Some", "none", "UnwrapNoneError", "Option", "Some", "none", "UnwrapNoneError",
"round64", "floor64", "ceil64", "round64", "floor64", "ceil64",
"extern", "kernel", "portable", "nac3", "extern", "kernel", "portable", "nac3",
@ -67,6 +67,12 @@ def Some(v: T) -> Option[T]:
none = Option(None) none = Option(None)
class _ConstGenericMarker:
pass
def ConstGeneric(name, constraint):
return TypeVar(name, _ConstGenericMarker, constraint)
def round64(x): def round64(x):
return round(x) return round(x)
@ -80,7 +86,13 @@ def ceil64(x):
import device_db import device_db
core_arguments = device_db.device_db["core"]["arguments"] core_arguments = device_db.device_db["core"]["arguments"]
compiler = nac3artiq.NAC3(core_arguments["target"]) artiq_builtins = {
"none": none,
"virtual": virtual,
"_ConstGenericMarker": _ConstGenericMarker,
"Option": Option,
}
compiler = nac3artiq.NAC3(core_arguments["target"], artiq_builtins)
allow_registration = True allow_registration = True
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side. # Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.
registered_functions = set() registered_functions = set()
@ -100,10 +112,15 @@ def extern(function):
register_function(function) register_function(function)
return function return function
def rpc(function):
"""Decorates a function declaration defined by the core device runtime.""" def rpc(arg=None, flags={}):
register_function(function) """Decorates a function or method to be executed on the host interpreter."""
return function if arg is None:
def inner_decorator(function):
return rpc(function, flags)
return inner_decorator
register_function(arg)
return arg
def kernel(function_or_method): def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device.""" """Decorates a function or method to be executed on the core device."""

26
nac3artiq/demo/str_abi.py Normal file
View File

@ -0,0 +1,26 @@
from min_artiq import *
from numpy import ndarray, zeros as np_zeros
@nac3
class StrFail:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@kernel
def hello(self, arg: str):
pass
@kernel
def consume_ndarray(self, arg: ndarray[str, 1]):
pass
def run(self):
self.hello("world")
self.consume_ndarray(np_zeros([10], dtype=str))
if __name__ == "__main__":
StrFail().run()

View File

@ -0,0 +1,24 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
core: KernelInvariant[Core]
attr1: KernelInvariant[str]
attr2: KernelInvariant[int32]
def __init__(self):
self.core = Core()
self.attr2 = 32
self.attr1 = "SAMPLE"
@kernel
def run(self):
print_int32(self.attr2)
self.attr1
if __name__ == "__main__":
Demo().run()

View File

@ -0,0 +1,40 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
attr1: KernelInvariant[int32] = 2
attr2: int32 = 4
attr3: Kernel[int32]
@kernel
def __init__(self):
self.attr3 = 8
@nac3
class NAC3Devices:
core: KernelInvariant[Core]
attr4: KernelInvariant[int32] = 16
def __init__(self):
self.core = Core()
@kernel
def run(self):
Demo.attr1 # Supported
# Demo.attr2 # Field not accessible on Kernel
# Demo.attr3 # Only attributes can be accessed in this way
# Demo.attr1 = 2 # Attributes are immutable
self.attr4 # Attributes can be accessed within class
obj = Demo()
obj.attr1 # Attributes can be accessed by class objects
NAC3Devices.attr4 # Attributes accessible for classes without __init__
if __name__ == "__main__":
NAC3Devices().run()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,23 @@
use inkwell::{values::BasicValueEnum, AddressSpace, AtomicOrdering}; use itertools::Either;
use nac3core::codegen::CodeGenContext;
use nac3core::{
codegen::CodeGenContext,
inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
},
};
/// Functions for manipulating the timeline.
pub trait TimeFns { pub trait TimeFns {
fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx>; /// Emits LLVM IR for `now_mu`.
fn emit_at_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, t: BasicValueEnum<'ctx>); fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>;
fn emit_delay_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, dt: BasicValueEnum<'ctx>);
/// Emits LLVM IR for `at_mu`.
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>);
/// Emits LLVM IR for `delay_mu`.
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>);
} }
pub struct NowPinningTimeFns64 {} pub struct NowPinningTimeFns64 {}
@ -12,141 +25,143 @@ pub struct NowPinningTimeFns64 {}
// For FPGA design reasons, on VexRiscv with 64-bit data bus, the "now" CSR is split into two 32-bit // For FPGA design reasons, on VexRiscv with 64-bit data bus, the "now" CSR is split into two 32-bit
// values that are each padded to 64-bits. // values that are each padded to 64-bits.
impl TimeFns for NowPinningTimeFns64 { impl TimeFns for NowPinningTimeFns64 {
fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx> { fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let now = ctx let now = ctx
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = let now_hiptr = ctx
ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::Generic), "now_hiptr"); .builder
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { .build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
let now_loptr = unsafe { .map(BasicValueEnum::into_pointer_value)
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now_gep") .unwrap();
};
if let (BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_lo)) = ( let now_loptr = unsafe {
ctx.builder.build_load(now_hiptr, "now_hi"), ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
ctx.builder.build_load(now_loptr, "now_lo"),
) {
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "now_zext_hi");
let shifted_hi = ctx.builder.build_left_shift(
zext_hi,
i64_type.const_int(32, false),
"now_shifted_zext_hi",
);
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "now_zext_lo");
ctx.builder.build_or(shifted_hi, zext_lo, "now_or").into()
} else {
unreachable!();
}
} else {
unreachable!();
} }
.unwrap();
let now_hi = ctx
.builder
.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi =
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into).unwrap()
} }
fn emit_at_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, t: BasicValueEnum<'ctx>) { fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
if let BasicValueEnum::IntValue(time) = t { let time = t.into_int_value();
let time_hi = ctx.builder.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"), let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type, i32_type,
"now_trunc", "",
); )
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc"); .unwrap();
let now = ctx let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
.module let now = ctx
.get_global("now") .module
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .get_global("now")
let now_hiptr = ctx.builder.build_bitcast( .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
now, let now_hiptr = ctx
i32_type.ptr_type(AddressSpace::Generic), .builder
"now_bitcast", .build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
); .map(BasicValueEnum::into_pointer_value)
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { .unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now_gep") let now_loptr = unsafe {
}; ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
ctx.builder
.build_store(now_hiptr, time_hi)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
} else {
unreachable!();
}
} else {
unreachable!();
} }
.unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
} }
fn emit_delay_mu<'ctx, 'a>( fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
dt: BasicValueEnum<'ctx>,
) {
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let now = ctx let now = ctx
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = let now_hiptr = ctx
ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::Generic), "now_hiptr"); .builder
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { .build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
let now_loptr = unsafe { .map(BasicValueEnum::into_pointer_value)
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now_loptr") .unwrap();
};
if let (
BasicValueEnum::IntValue(now_hi),
BasicValueEnum::IntValue(now_lo),
BasicValueEnum::IntValue(dt),
) = (
ctx.builder.build_load(now_hiptr, "now_hi"),
ctx.builder.build_load(now_loptr, "now_lo"),
dt,
) {
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "now_zext_hi");
let shifted_hi = ctx.builder.build_left_shift(
zext_hi,
i64_type.const_int(32, false),
"now_shifted_zext_hi",
);
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "now_zext_lo");
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now_or");
let time = ctx.builder.build_int_add(now_val, dt, "now_add"); let now_loptr = unsafe {
let time_hi = ctx.builder.build_int_truncate( ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
ctx.builder.build_right_shift( }
time, .unwrap();
i64_type.const_int(32, false),
false,
"now_lshr",
),
i32_type,
"now_trunc",
);
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc");
let now_hi = ctx
.builder
.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let dt = dt.into_int_value();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi =
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_right_shift(time, i64_type.const_int(32, false), false, "")
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .unwrap(),
.unwrap(); i32_type,
ctx.builder "time.hi",
.build_store(now_loptr, time_lo) )
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent) .unwrap();
.unwrap(); let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
} else {
unreachable!(); ctx.builder
} .build_store(now_hiptr, time_hi)
} else { .unwrap()
unreachable!(); .set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
}; .unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
} }
} }
@ -155,68 +170,67 @@ pub static NOW_PINNING_TIME_FNS_64: NowPinningTimeFns64 = NowPinningTimeFns64 {}
pub struct NowPinningTimeFns {} pub struct NowPinningTimeFns {}
impl TimeFns for NowPinningTimeFns { impl TimeFns for NowPinningTimeFns {
fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx> { fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let now = ctx let now = ctx
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now"); let now_raw = ctx
if let BasicValueEnum::IntValue(now_raw) = now_raw { .builder
let i64_32 = i64_type.const_int(32, false); .build_load(now.as_pointer_value(), "now")
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now_shl"); .map(BasicValueEnum::into_int_value)
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now_lshr"); .unwrap();
ctx.builder.build_or(now_lo, now_hi, "now_or").into()
} else { let i64_32 = i64_type.const_int(32, false);
unreachable!(); let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
} let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
ctx.builder.build_or(now_lo, now_hi, "now_mu").map(Into::into).unwrap()
} }
fn emit_at_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, t: BasicValueEnum<'ctx>) { fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
if let BasicValueEnum::IntValue(time) = t {
let time_hi = ctx.builder.build_int_truncate( let time = t.into_int_value();
ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"),
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "").unwrap(),
i32_type, i32_type,
"now_trunc", "time.hi",
); )
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc"); .unwrap();
let now = ctx let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc").unwrap();
.module let now = ctx
.get_global("now") .module
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .get_global("now")
let now_hiptr = ctx.builder.build_bitcast( .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
now, let now_hiptr = ctx
i32_type.ptr_type(AddressSpace::Generic), .builder
"now_bitcast", .build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
); .map(BasicValueEnum::into_pointer_value)
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { .unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now_gep") let now_loptr = unsafe {
}; ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
ctx.builder
.build_store(now_hiptr, time_hi)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
} else {
unreachable!();
}
} else {
unreachable!();
} }
.unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
} }
fn emit_delay_mu<'ctx, 'a>( fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
dt: BasicValueEnum<'ctx>,
) {
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
@ -224,41 +238,47 @@ impl TimeFns for NowPinningTimeFns {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now"); let now_raw = ctx
if let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = (now_raw, dt) { .builder
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now_shl"); .build_load(now.as_pointer_value(), "")
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now_lshr"); .map(BasicValueEnum::into_int_value)
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_or"); .unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "now_add");
let time_hi = ctx.builder.build_int_truncate( let dt = dt.into_int_value();
ctx.builder.build_right_shift(time, i64_32, false, "now_lshr"),
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type, i32_type,
"now_trunc", "now_trunc",
); )
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "now_trunc"); .unwrap();
let now_hiptr = ctx.builder.build_bitcast( let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
now, let now_hiptr = ctx
i32_type.ptr_type(AddressSpace::Generic), .builder
"now_bitcast", .build_bit_cast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
); .map(BasicValueEnum::into_pointer_value)
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr { .unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now_gep") let now_loptr = unsafe {
}; ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
ctx.builder
.build_store(now_hiptr, time_hi)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
} else {
unreachable!();
}
} else {
unreachable!();
} }
.unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
ctx.builder
.build_store(now_loptr, time_lo)
.unwrap()
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
} }
} }
@ -267,14 +287,18 @@ pub static NOW_PINNING_TIME_FNS: NowPinningTimeFns = NowPinningTimeFns {};
pub struct ExternTimeFns {} pub struct ExternTimeFns {}
impl TimeFns for ExternTimeFns { impl TimeFns for ExternTimeFns {
fn emit_now_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>) -> BasicValueEnum<'ctx> { fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx> {
let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| { let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| {
ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None) ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None)
}); });
ctx.builder.build_call(now_mu, &[], "now_mu").try_as_basic_value().left().unwrap() ctx.builder
.build_call(now_mu, &[], "now_mu")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap()
} }
fn emit_at_mu<'ctx, 'a>(&self, ctx: &mut CodeGenContext<'ctx, 'a>, t: BasicValueEnum<'ctx>) { fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
let at_mu = ctx.module.get_function("at_mu").unwrap_or_else(|| { let at_mu = ctx.module.get_function("at_mu").unwrap_or_else(|| {
ctx.module.add_function( ctx.module.add_function(
"at_mu", "at_mu",
@ -282,14 +306,10 @@ impl TimeFns for ExternTimeFns {
None, None,
) )
}); });
ctx.builder.build_call(at_mu, &[t.into()], "at_mu"); ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap();
} }
fn emit_delay_mu<'ctx, 'a>( fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
dt: BasicValueEnum<'ctx>,
) {
let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| { let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| {
ctx.module.add_function( ctx.module.add_function(
"delay_mu", "delay_mu",
@ -297,7 +317,7 @@ impl TimeFns for ExternTimeFns {
None, None,
) )
}); });
ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu"); ctx.builder.build_call(delay_mu, &[dt.into()], "delay_mu").unwrap();
} }
} }

View File

@ -2,7 +2,7 @@
name = "nac3ast" name = "nac3ast"
version = "0.1.0" version = "0.1.0"
authors = ["RustPython Team", "M-Labs"] authors = ["RustPython Team", "M-Labs"]
edition = "2018" edition = "2021"
[features] [features]
default = ["constant-optimization", "fold"] default = ["constant-optimization", "fold"]
@ -10,7 +10,7 @@ constant-optimization = ["fold"]
fold = [] fold = []
[dependencies] [dependencies]
lazy_static = "1.4.0" lazy_static = "1.5"
parking_lot = "0.11.1" parking_lot = "0.12"
string-interner = "0.13.0" string-interner = "0.17"
fxhash = "0.2.1" fxhash = "0.2"

File diff suppressed because it is too large Load Diff

View File

@ -28,12 +28,12 @@ impl From<bool> for Constant {
} }
impl From<i32> for Constant { impl From<i32> for Constant {
fn from(i: i32) -> Constant { fn from(i: i32) -> Constant {
Self::Int(i as i128) Self::Int(i128::from(i))
} }
} }
impl From<i64> for Constant { impl From<i64> for Constant {
fn from(i: i64) -> Constant { fn from(i: i64) -> Constant {
Self::Int(i as i128) Self::Int(i128::from(i))
} }
} }
@ -50,6 +50,7 @@ pub enum ConversionFlag {
} }
impl ConversionFlag { impl ConversionFlag {
#[must_use]
pub fn try_from_byte(b: u8) -> Option<Self> { pub fn try_from_byte(b: u8) -> Option<Self> {
match b { match b {
b's' => Some(Self::Str), b's' => Some(Self::Str),
@ -69,6 +70,7 @@ pub struct ConstantOptimizer {
#[cfg(feature = "constant-optimization")] #[cfg(feature = "constant-optimization")]
impl ConstantOptimizer { impl ConstantOptimizer {
#[inline] #[inline]
#[must_use]
pub fn new() -> Self { pub fn new() -> Self {
Self { _priv: () } Self { _priv: () }
} }
@ -85,33 +87,22 @@ impl<U> crate::fold::Fold<U> for ConstantOptimizer {
fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> { fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> {
match node.node { match node.node {
crate::ExprKind::Tuple { elts, ctx } => { crate::ExprKind::Tuple { elts, ctx } => {
let elts = elts let elts =
.into_iter() elts.into_iter().map(|x| self.fold_expr(x)).collect::<Result<Vec<_>, _>>()?;
.map(|x| self.fold_expr(x)) let expr =
.collect::<Result<Vec<_>, _>>()?; if elts.iter().all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) {
let expr = if elts let tuple = elts
.iter() .into_iter()
.all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) .map(|e| match e.node {
{ crate::ExprKind::Constant { value, .. } => value,
let tuple = elts _ => unreachable!(),
.into_iter() })
.map(|e| match e.node { .collect();
crate::ExprKind::Constant { value, .. } => value, crate::ExprKind::Constant { value: Constant::Tuple(tuple), kind: None }
_ => unreachable!(), } else {
}) crate::ExprKind::Tuple { elts, ctx }
.collect(); };
crate::ExprKind::Constant { Ok(crate::Expr { node: expr, custom: node.custom, location: node.location })
value: Constant::Tuple(tuple),
kind: None,
}
} else {
crate::ExprKind::Tuple { elts, ctx }
};
Ok(crate::Expr {
node: expr,
custom: node.custom,
location: node.location,
})
} }
_ => crate::fold::fold_expr(self, node), _ => crate::fold::fold_expr(self, node),
} }
@ -127,7 +118,7 @@ mod tests {
use crate::fold::Fold; use crate::fold::Fold;
use crate::*; use crate::*;
let location = Location::new(0, 0, Default::default()); let location = Location::new(0, 0, FileName::default());
let custom = (); let custom = ();
let ast = Located { let ast = Located {
location, location,
@ -138,18 +129,12 @@ mod tests {
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 1.into(), kind: None },
value: 1.into(),
kind: None,
},
}, },
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 2.into(), kind: None },
value: 2.into(),
kind: None,
},
}, },
Located { Located {
location, location,
@ -160,26 +145,17 @@ mod tests {
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 3.into(), kind: None },
value: 3.into(),
kind: None,
},
}, },
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 4.into(), kind: None },
value: 4.into(),
kind: None,
},
}, },
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 5.into(), kind: None },
value: 5.into(),
kind: None,
},
}, },
], ],
}, },
@ -187,9 +163,7 @@ mod tests {
], ],
}, },
}; };
let new_ast = ConstantOptimizer::new() let new_ast = ConstantOptimizer::new().fold_expr(ast).unwrap_or_else(|e| match e {});
.fold_expr(ast)
.unwrap_or_else(|e| match e {});
assert_eq!( assert_eq!(
new_ast, new_ast,
Located { Located {
@ -199,11 +173,7 @@ mod tests {
value: Constant::Tuple(vec![ value: Constant::Tuple(vec![
1.into(), 1.into(),
2.into(), 2.into(),
Constant::Tuple(vec![ Constant::Tuple(vec![3.into(), 4.into(), 5.into(),])
3.into(),
4.into(),
5.into(),
])
]), ]),
kind: None kind: None
}, },

View File

@ -64,11 +64,4 @@ macro_rules! simple_fold {
}; };
} }
simple_fold!( simple_fold!(usize, String, bool, StrRef, constant::Constant, constant::ConversionFlag);
usize,
String,
bool,
StrRef,
constant::Constant,
constant::ConversionFlag
);

View File

@ -2,6 +2,7 @@ use crate::{Constant, ExprKind};
impl<U> ExprKind<U> { impl<U> ExprKind<U> {
/// Returns a short name for the node suitable for use in error messages. /// Returns a short name for the node suitable for use in error messages.
#[must_use]
pub fn name(&self) -> &'static str { pub fn name(&self) -> &'static str {
match self { match self {
ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => { ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => {
@ -34,10 +35,7 @@ impl<U> ExprKind<U> {
ExprKind::Starred { .. } => "starred", ExprKind::Starred { .. } => "starred",
ExprKind::Slice { .. } => "slice", ExprKind::Slice { .. } => "slice",
ExprKind::JoinedStr { values } => { ExprKind::JoinedStr { values } => {
if values if values.iter().any(|e| matches!(e.node, ExprKind::JoinedStr { .. })) {
.iter()
.any(|e| matches!(e.node, ExprKind::JoinedStr { .. }))
{
"f-string expression" "f-string expression"
} else { } else {
"literal" "literal"

View File

@ -1,3 +1,19 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::too_many_lines,
clippy::wildcard_imports
)]
#[macro_use] #[macro_use]
extern crate lazy_static; extern crate lazy_static;
@ -9,6 +25,6 @@ mod impls;
mod location; mod location;
pub use ast_gen::*; pub use ast_gen::*;
pub use location::{Location, FileName}; pub use location::{FileName, Location};
pub type Suite<U = ()> = Vec<Stmt<U>>; pub type Suite<U = ()> = Vec<Stmt<U>>;

View File

@ -1,8 +1,9 @@
//! Datatypes to support source location information. //! Datatypes to support source location information.
use crate::ast_gen::StrRef; use crate::ast_gen::StrRef;
use std::cmp::Ordering;
use std::fmt; use std::fmt;
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FileName(pub StrRef); pub struct FileName(pub StrRef);
impl Default for FileName { impl Default for FileName {
fn default() -> Self { fn default() -> Self {
@ -17,16 +18,38 @@ impl From<String> for FileName {
} }
/// A location somewhere in the sourcecode. /// A location somewhere in the sourcecode.
#[derive(Clone, Copy, Debug, Default, PartialEq)] #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Location { pub struct Location {
pub row: usize, pub row: usize,
pub column: usize, pub column: usize,
pub file: FileName pub file: FileName,
} }
impl fmt::Display for Location { impl fmt::Display for Location {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}: line {} column {}", self.file.0, self.row, self.column) write!(f, "{}:{}:{}", self.file.0, self.row, self.column)
}
}
impl Ord for Location {
fn cmp(&self, other: &Self) -> Ordering {
let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string());
if file_cmp != Ordering::Equal {
return file_cmp;
}
let row_cmp = self.row.cmp(&other.row);
if row_cmp != Ordering::Equal {
return row_cmp;
}
self.column.cmp(&other.column)
}
}
impl PartialOrd for Location {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
} }
} }
@ -53,23 +76,22 @@ impl Location {
) )
} }
} }
Visualize { Visualize { loc: *self, line, desc }
loc: *self,
line,
desc,
}
} }
} }
impl Location { impl Location {
#[must_use]
pub fn new(row: usize, column: usize, file: FileName) -> Self { pub fn new(row: usize, column: usize, file: FileName) -> Self {
Location { row, column, file } Location { row, column, file }
} }
#[must_use]
pub fn row(&self) -> usize { pub fn row(&self) -> usize {
self.row self.row
} }
#[must_use]
pub fn column(&self) -> usize { pub fn column(&self) -> usize {
self.column self.column
} }

View File

@ -2,25 +2,30 @@
name = "nac3core" name = "nac3core"
version = "0.1.0" version = "0.1.0"
authors = ["M-Labs"] authors = ["M-Labs"]
edition = "2018" edition = "2021"
[features]
no-escape-analysis = []
[dependencies] [dependencies]
itertools = "0.10.1" itertools = "0.13"
crossbeam = "0.8.1" crossbeam = "0.8"
parking_lot = "0.11.1" indexmap = "2.6"
rayon = "1.5.1" parking_lot = "0.12"
slab = "0.4.6" rayon = "1.10"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
strum = "0.26"
strum_macros = "0.26"
[dependencies.inkwell] [dependencies.inkwell]
version = "0.1.0-beta.4" version = "0.5"
default-features = false default-features = false
features = ["llvm13-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"] features = ["llvm14-0-prefer-dynamic", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]
[dev-dependencies] [dev-dependencies]
test-case = "1.2.0" test-case = "1.2.0"
indoc = "1.0" indoc = "2.0"
insta = "=1.11.0" insta = "=1.11.0"
[build-dependencies] [build-dependencies]
regex = "1" regex = "1.10"

View File

@ -1,4 +1,3 @@
use regex::Regex;
use std::{ use std::{
env, env,
fs::File, fs::File,
@ -7,30 +6,55 @@ use std::{
process::{Command, Stdio}, process::{Command, Stdio},
}; };
use regex::Regex;
fn main() { fn main() {
const FILE: &str = "src/codegen/irrt/irrt.c";
println!("cargo:rerun-if-changed={}", FILE);
let out_dir = env::var("OUT_DIR").unwrap(); let out_dir = env::var("OUT_DIR").unwrap();
let out_path = Path::new(&out_dir); let out_dir = Path::new(&out_dir);
let irrt_dir = Path::new("irrt");
let irrt_cpp_path = irrt_dir.join("irrt.cpp");
/* /*
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode. * HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
* Compiling for WASM32 and filtering the output with regex is the closest we can get. * Compiling for WASM32 and filtering the output with regex is the closest we can get.
*/ */
let mut flags: Vec<&str> = vec![
const FLAG: &[&str] = &[
"--target=wasm32", "--target=wasm32",
FILE, "-x",
"-O3", "c++",
"-std=c++20",
"-fno-discard-value-names",
"-fno-exceptions",
"-fno-rtti",
"-emit-llvm", "-emit-llvm",
"-S", "-S",
"-Wall", "-Wall",
"-Wextra", "-Wextra",
"-o", "-o",
"-", "-",
"-I",
irrt_dir.to_str().unwrap(),
irrt_cpp_path.to_str().unwrap(),
]; ];
let output = Command::new("clang")
.args(FLAG) match env::var("PROFILE").as_deref() {
Ok("debug") => {
flags.push("-O0");
flags.push("-DIRRT_DEBUG_ASSERT");
}
Ok("release") => {
flags.push("-O3");
}
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
}
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
// Compile IRRT and capture the LLVM IR output
let output = Command::new("clang-irrt")
.args(flags)
.output() .output()
.map(|o| { .map(|o| {
assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap()); assert!(o.status.success(), "{}", std::str::from_utf8(&o.stderr).unwrap());
@ -42,9 +66,19 @@ fn main() {
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n"); let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
let mut filtered_output = String::with_capacity(output.len()); let mut filtered_output = String::with_capacity(output.len());
let regex_filter = regex::Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap(); // Filter out irrelevant IR
//
// Regex:
// - `(?ms:^define.*?\}$)` captures LLVM `define` blocks
// - `(?m:^declare.*?$)` captures LLVM `declare` lines
// - `(?m:^%.+?=\s*type\s*\{.+?\}$)` captures LLVM `type` declarations
// - `(?m:^@.+?=.+$)` captures global constants
let regex_filter = Regex::new(
r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)|(?m:^@.+?=.+$)",
)
.unwrap();
for f in regex_filter.captures_iter(&output) { for f in regex_filter.captures_iter(&output) {
assert!(f.len() == 1); assert_eq!(f.len(), 1);
filtered_output.push_str(&f[0]); filtered_output.push_str(&f[0]);
filtered_output.push('\n'); filtered_output.push('\n');
} }
@ -53,20 +87,24 @@ fn main() {
.unwrap() .unwrap()
.replace_all(&filtered_output, ""); .replace_all(&filtered_output, "");
println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT"); // For debugging
if env::var("DEBUG_DUMP_IRRT").is_ok() { // Doing `DEBUG_DUMP_IRRT=1 cargo build -p nac3core` dumps the LLVM IR generated
let mut file = File::create(out_path.join("irrt.ll")).unwrap(); const DEBUG_DUMP_IRRT: &str = "DEBUG_DUMP_IRRT";
println!("cargo:rerun-if-env-changed={DEBUG_DUMP_IRRT}");
if env::var(DEBUG_DUMP_IRRT).is_ok() {
let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
file.write_all(output.as_bytes()).unwrap(); file.write_all(output.as_bytes()).unwrap();
let mut file = File::create(out_path.join("irrt-filtered.ll")).unwrap();
let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
file.write_all(filtered_output.as_bytes()).unwrap(); file.write_all(filtered_output.as_bytes()).unwrap();
} }
let mut llvm_as = Command::new("llvm-as") let mut llvm_as = Command::new("llvm-as-irrt")
.stdin(Stdio::piped()) .stdin(Stdio::piped())
.arg("-o") .arg("-o")
.arg(out_path.join("irrt.bc")) .arg(out_dir.join("irrt.bc"))
.spawn() .spawn()
.unwrap(); .unwrap();
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap(); llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
assert!(llvm_as.wait().unwrap().success()) assert!(llvm_as.wait().unwrap().success());
} }

6
nac3core/irrt/irrt.cpp Normal file
View File

@ -0,0 +1,6 @@
#include "irrt/exception.hpp"
#include "irrt/int_types.hpp"
#include "irrt/list.hpp"
#include "irrt/math.hpp"
#include "irrt/ndarray.hpp"
#include "irrt/slice.hpp"

View File

@ -0,0 +1,9 @@
#pragma once
#include "irrt/int_types.hpp"
template<typename SizeT>
struct CSlice {
uint8_t* base;
SizeT len;
};

View File

@ -0,0 +1,25 @@
#pragma once
// Set in nac3core/build.rs
#ifdef IRRT_DEBUG_ASSERT
#define IRRT_DEBUG_ASSERT_BOOL true
#else
#define IRRT_DEBUG_ASSERT_BOOL false
#endif
#define raise_debug_assert(SizeT, msg, param1, param2, param3) \
raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3)
#define debug_assert_eq(SizeT, lhs, rhs) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if ((lhs) != (rhs)) { \
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
} \
}
#define debug_assert(SizeT, expr) \
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
if (!(expr)) { \
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \
} \
}

View File

@ -0,0 +1,82 @@
#pragma once
#include "irrt/cslice.hpp"
#include "irrt/int_types.hpp"
/**
* @brief The int type of ARTIQ exception IDs.
*/
typedef int32_t ExceptionId;
/*
* Set of exceptions C++ IRRT can use.
* Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`.
*/
extern "C" {
ExceptionId EXN_INDEX_ERROR;
ExceptionId EXN_VALUE_ERROR;
ExceptionId EXN_ASSERTION_ERROR;
ExceptionId EXN_TYPE_ERROR;
}
/**
* @brief Extern function to `__nac3_raise`
*
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
*/
extern "C" void __nac3_raise(void* err);
namespace {
/**
* @brief NAC3's Exception struct
*/
template<typename SizeT>
struct Exception {
ExceptionId id;
CSlice<SizeT> filename;
int32_t line;
int32_t column;
CSlice<SizeT> function;
CSlice<SizeT> msg;
int64_t params[3];
};
constexpr int64_t NO_PARAM = 0;
template<typename SizeT>
void _raise_exception_helper(ExceptionId id,
const char* filename,
int32_t line,
const char* function,
const char* msg,
int64_t param0,
int64_t param1,
int64_t param2) {
Exception<SizeT> e = {
.id = id,
.filename = {.base = reinterpret_cast<const uint8_t*>(filename), .len = __builtin_strlen(filename)},
.line = line,
.column = 0,
.function = {.base = reinterpret_cast<const uint8_t*>(function), .len = __builtin_strlen(function)},
.msg = {.base = reinterpret_cast<const uint8_t*>(msg), .len = __builtin_strlen(msg)},
};
e.params[0] = param0;
e.params[1] = param1;
e.params[2] = param2;
__nac3_raise(reinterpret_cast<void*>(&e));
__builtin_unreachable();
}
/**
* @brief Raise an exception with location details (location in the IRRT source files).
* @param SizeT The runtime `size_t` type.
* @param id The ID of the exception to raise.
* @param msg A global constant C-string of the error message.
*
* `param0` to `param2` are optional format arguments of `msg`. They should be set to
* `NO_PARAM` to indicate they are unused.
*/
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)
} // namespace

View File

@ -0,0 +1,22 @@
#pragma once
#if __STDC_VERSION__ >= 202000
using int8_t = _BitInt(8);
using uint8_t = unsigned _BitInt(8);
using int32_t = _BitInt(32);
using uint32_t = unsigned _BitInt(32);
using int64_t = _BitInt(64);
using uint64_t = unsigned _BitInt(64);
#else
using int8_t = _ExtInt(8);
using uint8_t = unsigned _ExtInt(8);
using int32_t = _ExtInt(32);
using uint32_t = unsigned _ExtInt(32);
using int64_t = _ExtInt(64);
using uint64_t = unsigned _ExtInt(64);
#endif
// NDArray indices are always `uint32_t`.
using NDIndex = uint32_t;
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
using SliceIndex = int32_t;

View File

@ -0,0 +1,75 @@
#pragma once
#include "irrt/int_types.hpp"
#include "irrt/math_util.hpp"
extern "C" {
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
SliceIndex dest_end,
SliceIndex dest_step,
uint8_t* dest_arr,
SliceIndex dest_arr_len,
SliceIndex src_start,
SliceIndex src_end,
SliceIndex src_step,
uint8_t* src_arr,
SliceIndex src_arr_len,
const SliceIndex size) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0)
return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca = (dest_arr == src_arr)
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|| max(src_start, src_end) < min(dest_start, dest_end));
if (need_alloca) {
uint8_t* tmp = reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
SliceIndex src_ind = src_start;
SliceIndex dest_ind = dest_start;
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}
} // extern "C"

View File

@ -0,0 +1,93 @@
#pragma once
namespace {
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
template<typename T>
T __nac3_int_exp_impl(T base, T exp) {
T res = 1;
/* repeated squaring method */
do {
if (exp & 1) {
res *= base; /* for n odd */
}
exp >>= 1;
base *= base;
} while (exp);
return res;
}
} // namespace
#define DEF_nac3_int_exp_(T) \
T __nac3_int_exp_##T(T base, T exp) { \
return __nac3_int_exp_impl(base, exp); \
}
extern "C" {
// Putting semicolons here to make clang-format not reformat this into
// a stair shape.
DEF_nac3_int_exp_(int32_t);
DEF_nac3_int_exp_(int64_t);
DEF_nac3_int_exp_(uint32_t);
DEF_nac3_int_exp_(uint64_t);
int32_t __nac3_isinf(double x) {
return __builtin_isinf(x);
}
int32_t __nac3_isnan(double x) {
return __builtin_isnan(x);
}
double tgamma(double arg);
double __nac3_gamma(double z) {
// Handling for denormals
// | x | Python gamma(x) | C tgamma(x) |
// --- | ----------------- | --------------- | ----------- |
// (1) | nan | nan | nan |
// (2) | -inf | -inf | inf |
// (3) | inf | inf | inf |
// (4) | 0.0 | inf | inf |
// (5) | {-1.0, -2.0, ...} | inf | nan |
// (1)-(3)
if (__builtin_isinf(z) || __builtin_isnan(z)) {
return z;
}
double v = tgamma(z);
// (4)-(5)
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
}
double lgamma(double arg);
double __nac3_gammaln(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: gammaln(-inf) -> -inf
// - libm : lgamma(-inf) -> inf
if (__builtin_isinf(x)) {
return x;
}
return lgamma(x);
}
double j0(double x);
double __nac3_j0(double x) {
// libm's handling of value overflows differs from scipy:
// - scipy: j0(inf) -> nan
// - libm : j0(inf) -> 0.0
if (__builtin_isinf(x)) {
return __builtin_nan("");
}
return j0(x);
}
}

View File

@ -0,0 +1,13 @@
#pragma once
namespace {
template<typename T>
const T& max(const T& a, const T& b) {
return a > b ? a : b;
}
template<typename T>
const T& min(const T& a, const T& b) {
return a > b ? b : a;
}
} // namespace

View File

@ -0,0 +1,144 @@
#pragma once
#include "irrt/int_types.hpp"
namespace {
template<typename SizeT>
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
__builtin_assume(end_idx <= list_len);
SizeT num_elems = 1;
for (SizeT i = begin_idx; i < end_idx; ++i) {
SizeT val = list_data[i];
__builtin_assume(val > 0);
num_elems *= val;
}
return num_elems;
}
template<typename SizeT>
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) {
SizeT stride = 1;
for (SizeT dim = 0; dim < num_dims; dim++) {
SizeT i = num_dims - dim - 1;
__builtin_assume(dims[i] > 0);
idxs[i] = (index / stride) % dims[i];
stride *= dims[i];
}
}
template<typename SizeT>
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) {
SizeT idx = 0;
SizeT stride = 1;
for (SizeT i = 0; i < num_dims; ++i) {
SizeT ri = num_dims - i - 1;
if (ri < num_indices) {
idx += stride * indices[ri];
}
__builtin_assume(dims[i] > 0);
stride *= dims[ri];
}
return idx;
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
SizeT lhs_ndims,
const SizeT* rhs_dims,
SizeT rhs_ndims,
SizeT* out_dims) {
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (SizeT i = 0; i < max_ndims; ++i) {
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
SizeT* out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == nullptr) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == nullptr) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
template<typename SizeT>
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
SizeT src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
for (SizeT i = 0; i < src_ndims; ++i) {
SizeT src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
} // namespace
extern "C" {
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
uint64_t
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
}
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) {
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
}
uint32_t
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
uint64_t
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) {
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
}
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
uint32_t lhs_ndims,
const uint32_t* rhs_dims,
uint32_t rhs_ndims,
uint32_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
uint64_t lhs_ndims,
const uint64_t* rhs_dims,
uint64_t rhs_ndims,
uint64_t* out_dims) {
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
}
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
uint32_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
uint64_t src_ndims,
const NDIndex* in_idx,
NDIndex* out_idx) {
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
}
}

View File

@ -0,0 +1,28 @@
#pragma once
#include "irrt/int_types.hpp"
extern "C" {
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
SliceIndex diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,20 @@
use std::collections::HashMap;
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use crate::{ use crate::{
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::DefinitionId, toplevel::DefinitionId,
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, typedef::{
into_var_map, FunSignature, FuncArg, Type, TypeEnum, TypeVar, TypeVarId, Unifier,
},
}, },
}; };
use nac3parser::ast::StrRef;
use std::collections::HashMap;
pub struct ConcreteTypeStore { pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>, store: Vec<ConcreteTypeEnum>,
} }
@ -22,6 +27,7 @@ pub struct ConcreteFuncArg {
pub name: StrRef, pub name: StrRef,
pub ty: ConcreteType, pub ty: ConcreteType,
pub default_value: Option<SymbolValue>, pub default_value: Option<SymbolValue>,
pub is_vararg: bool,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -43,14 +49,12 @@ pub enum ConcreteTypeEnum {
TPrimitive(Primitive), TPrimitive(Primitive),
TTuple { TTuple {
ty: Vec<ConcreteType>, ty: Vec<ConcreteType>,
}, is_vararg_ctx: bool,
TList {
ty: ConcreteType,
}, },
TObj { TObj {
obj_id: DefinitionId, obj_id: DefinitionId,
fields: HashMap<StrRef, (ConcreteType, bool)>, fields: HashMap<StrRef, (ConcreteType, bool)>,
params: HashMap<u32, ConcreteType>, params: IndexMap<TypeVarId, ConcreteType>,
}, },
TVirtual { TVirtual {
ty: ConcreteType, ty: ConcreteType,
@ -58,11 +62,15 @@ pub enum ConcreteTypeEnum {
TFunc { TFunc {
args: Vec<ConcreteFuncArg>, args: Vec<ConcreteFuncArg>,
ret: ConcreteType, ret: ConcreteType,
vars: HashMap<u32, ConcreteType>, vars: HashMap<TypeVarId, ConcreteType>,
},
TLiteral {
values: Vec<SymbolValue>,
}, },
} }
impl ConcreteTypeStore { impl ConcreteTypeStore {
#[must_use]
pub fn new() -> ConcreteTypeStore { pub fn new() -> ConcreteTypeStore {
ConcreteTypeStore { ConcreteTypeStore {
store: vec![ store: vec![
@ -80,6 +88,7 @@ impl ConcreteTypeStore {
} }
} }
#[must_use]
pub fn get(&self, cty: ConcreteType) -> &ConcreteTypeEnum { pub fn get(&self, cty: ConcreteType) -> &ConcreteTypeEnum {
&self.store[cty.0] &self.store[cty.0]
} }
@ -97,8 +106,16 @@ impl ConcreteTypeStore {
.iter() .iter()
.map(|arg| ConcreteFuncArg { .map(|arg| ConcreteFuncArg {
name: arg.name, name: arg.name,
ty: self.from_unifier_type(unifier, primitives, arg.ty, cache), ty: if arg.is_vararg {
let tuple_ty = unifier
.add_ty(TypeEnum::TTuple { ty: vec![arg.ty], is_vararg_ctx: true });
self.from_unifier_type(unifier, primitives, tuple_ty, cache)
} else {
self.from_unifier_type(unifier, primitives, arg.ty, cache)
},
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: arg.is_vararg,
}) })
.collect(), .collect(),
ret: self.from_unifier_type(unifier, primitives, signature.ret, cache), ret: self.from_unifier_type(unifier, primitives, signature.ret, cache),
@ -153,14 +170,12 @@ impl ConcreteTypeStore {
cache.insert(ty, None); cache.insert(ty, None);
let ty_enum = unifier.get_ty(ty); let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum { let result = match &*ty_enum {
TypeEnum::TTuple { ty } => ConcreteTypeEnum::TTuple { TypeEnum::TTuple { ty, is_vararg_ctx } => ConcreteTypeEnum::TTuple {
ty: ty ty: ty
.iter() .iter()
.map(|t| self.from_unifier_type(unifier, primitives, *t, cache)) .map(|t| self.from_unifier_type(unifier, primitives, *t, cache))
.collect(), .collect(),
}, is_vararg_ctx: *is_vararg_ctx,
TypeEnum::TList { ty } => ConcreteTypeEnum::TList {
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
}, },
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj { TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
obj_id: *obj_id, obj_id: *obj_id,
@ -194,9 +209,12 @@ impl ConcreteTypeStore {
ty: self.from_unifier_type(unifier, primitives, *ty, cache), ty: self.from_unifier_type(unifier, primitives, *ty, cache),
}, },
TypeEnum::TFunc(signature) => { TypeEnum::TFunc(signature) => {
self.from_signature(unifier, primitives, &*signature, cache) self.from_signature(unifier, primitives, signature, cache)
} }
_ => unreachable!(), TypeEnum::TLiteral { values, .. } => {
ConcreteTypeEnum::TLiteral { values: values.clone() }
}
_ => unreachable!("{:?}", ty_enum.get_type_name()),
}; };
let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() { let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() {
self.store[*index] = result; self.store[*index] = result;
@ -221,7 +239,7 @@ impl ConcreteTypeStore {
return if let Some(ty) = ty { return if let Some(ty) = ty {
*ty *ty
} else { } else {
*ty = Some(unifier.get_dummy_var().0); *ty = Some(unifier.get_dummy_var().ty);
ty.unwrap() ty.unwrap()
}; };
} }
@ -243,15 +261,13 @@ impl ConcreteTypeStore {
*cache.get_mut(&cty).unwrap() = Some(ty); *cache.get_mut(&cty).unwrap() = Some(ty);
return ty; return ty;
} }
ConcreteTypeEnum::TTuple { ty } => TypeEnum::TTuple { ConcreteTypeEnum::TTuple { ty, is_vararg_ctx } => TypeEnum::TTuple {
ty: ty ty: ty
.iter() .iter()
.map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache)) .map(|cty| self.to_unifier_type(unifier, primitives, *cty, cache))
.collect(), .collect(),
is_vararg_ctx: *is_vararg_ctx,
}, },
ConcreteTypeEnum::TList { ty } => {
TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
}
ConcreteTypeEnum::TVirtual { ty } => { ConcreteTypeEnum::TVirtual { ty } => {
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
} }
@ -263,10 +279,10 @@ impl ConcreteTypeStore {
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
}) })
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
params: params params: into_var_map(params.iter().map(|(&id, cty)| {
.iter() let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) TypeVar { id, ty }
.collect::<HashMap<_, _>>(), })),
}, },
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
args: args args: args
@ -275,14 +291,18 @@ impl ConcreteTypeStore {
name: arg.name, name: arg.name,
ty: self.to_unifier_type(unifier, primitives, arg.ty, cache), ty: self.to_unifier_type(unifier, primitives, arg.ty, cache),
default_value: arg.default_value.clone(), default_value: arg.default_value.clone(),
is_vararg: false,
}) })
.collect(), .collect(),
ret: self.to_unifier_type(unifier, primitives, *ret, cache), ret: self.to_unifier_type(unifier, primitives, *ret, cache),
vars: vars vars: into_var_map(vars.iter().map(|(&id, cty)| {
.iter() let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) TypeVar { id, ty }
.collect::<HashMap<_, _>>(), })),
}), }),
ConcreteTypeEnum::TLiteral { values, .. } => {
TypeEnum::TLiteral { values: values.clone(), loc: None }
}
}; };
let result = unifier.add_ty(result); let result = unifier.add_ty(result);
if let Some(ty) = cache.get(&cty).unwrap() { if let Some(ty) = cache.get(&cty).unwrap() {

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,193 @@
use inkwell::{
attributes::{Attribute, AttributeLoc},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
};
use itertools::Either;
use crate::codegen::CodeGenContext;
/// Macro to generate extern function
/// Both function return type and function parameter type are `FloatValue`
///
/// Arguments:
/// * `unary/binary`: Whether the extern function requires one (unary) or two (binary) operands
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
///
/// Optional Arguments:
/// * `$(,$attributes:literal)*)`: Attributes linked with the extern function.
/// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly".
/// These will be used unless other attributes are specified
/// * `$(,$args:ident)*`: Operands of the extern function
/// The data type of these operands will be set to `FloatValue`
///
macro_rules! generate_extern_fn {
("unary", $fn_name:ident, $extern_fn:literal) => {
generate_extern_fn!($fn_name, $extern_fn, arg, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
};
("unary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
generate_extern_fn!($fn_name, $extern_fn, arg $(,$attributes)*);
};
("binary", $fn_name:ident, $extern_fn:literal) => {
generate_extern_fn!($fn_name, $extern_fn, arg1, arg2, "mustprogress", "nofree", "nounwind", "willreturn", "writeonly");
};
("binary", $fn_name:ident, $extern_fn:literal $(,$attributes:literal)*) => {
generate_extern_fn!($fn_name, $extern_fn, arg1, arg2 $(,$attributes)*);
};
($fn_name:ident, $extern_fn:literal $(,$args:ident)* $(,$attributes:literal)*) => {
#[doc = concat!("Invokes the [`", stringify!($extern_fn), "`](https://en.cppreference.com/w/c/numeric/math/", stringify!($llvm_name), ") function." )]
pub fn $fn_name<'ctx>(
ctx: &CodeGenContext<'ctx, '_>
$(,$args: FloatValue<'ctx>)*,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = $extern_fn;
let llvm_f64 = ctx.ctx.f64_type();
$(debug_assert_eq!($args.get_type(), llvm_f64);)*
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[$($args.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in [$($attributes),*] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[$($args.into()),*], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
};
}
generate_extern_fn!("unary", call_tan, "tan");
generate_extern_fn!("unary", call_asin, "asin");
generate_extern_fn!("unary", call_acos, "acos");
generate_extern_fn!("unary", call_atan, "atan");
generate_extern_fn!("unary", call_sinh, "sinh");
generate_extern_fn!("unary", call_cosh, "cosh");
generate_extern_fn!("unary", call_tanh, "tanh");
generate_extern_fn!("unary", call_asinh, "asinh");
generate_extern_fn!("unary", call_acosh, "acosh");
generate_extern_fn!("unary", call_atanh, "atanh");
generate_extern_fn!("unary", call_expm1, "expm1");
generate_extern_fn!(
"unary",
call_cbrt,
"cbrt",
"mustprogress",
"nofree",
"nosync",
"nounwind",
"readonly",
"willreturn"
);
generate_extern_fn!("unary", call_erf, "erf", "nounwind");
generate_extern_fn!("unary", call_erfc, "erfc", "nounwind");
generate_extern_fn!("unary", call_j1, "j1", "nounwind");
generate_extern_fn!("binary", call_atan2, "atan2");
generate_extern_fn!("binary", call_hypot, "hypot", "nounwind");
generate_extern_fn!("binary", call_nextafter, "nextafter", "nounwind");
/// Invokes the [`ldexp`](https://en.cppreference.com/w/c/numeric/math/ldexp) function.
pub fn call_ldexp<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
arg: FloatValue<'ctx>,
exp: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "ldexp";
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
debug_assert_eq!(arg.get_type(), llvm_f64);
debug_assert_eq!(exp.get_type(), llvm_i32);
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into(), llvm_i32.into()], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder
.build_call(extern_fn, &[arg.into(), exp.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Macro to generate `np_linalg` and `sp_linalg` functions
/// The function takes as input `NDArray` and returns ()
///
/// Arguments:
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$extern_fn:literal`: Name of underlying extern function
/// * (2/3/4): Number of `NDArray` that function takes as input
///
/// Note:
/// The operands and resulting `NDArray` are both passed as input to the funcion
/// It is the responsibility of caller to ensure that output `NDArray` is properly allocated on stack
/// The function changes the content of the output `NDArray` in-place
macro_rules! generate_linalg_extern_fn {
($fn_name:ident, $extern_fn:literal, 2) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2);
};
($fn_name:ident, $extern_fn:literal, 3) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3);
};
($fn_name:ident, $extern_fn:literal, 4) => {
generate_linalg_extern_fn!($fn_name, $extern_fn, mat1, mat2, mat3, mat4);
};
($fn_name:ident, $extern_fn:literal $(,$input_matrix:ident)*) => {
#[doc = concat!("Invokes the linalg `", stringify!($extern_fn), " function." )]
pub fn $fn_name<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>
$(,$input_matrix: BasicValueEnum<'ctx>)*,
name: Option<&str>,
){
const FN_NAME: &str = $extern_fn;
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let fn_type = ctx.ctx.void_type().fn_type(&[$($input_matrix.get_type().into()),*], false);
let func = ctx.module.add_function(FN_NAME, fn_type, None);
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
func
});
ctx.builder.build_call(extern_fn, &[$($input_matrix.into(),)*], name.unwrap_or_default()).unwrap();
}
};
}
generate_linalg_extern_fn!(call_np_linalg_cholesky, "np_linalg_cholesky", 2);
generate_linalg_extern_fn!(call_np_linalg_qr, "np_linalg_qr", 3);
generate_linalg_extern_fn!(call_np_linalg_svd, "np_linalg_svd", 4);
generate_linalg_extern_fn!(call_np_linalg_inv, "np_linalg_inv", 2);
generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
generate_linalg_extern_fn!(call_np_linalg_matrix_power, "np_linalg_matrix_power", 3);
generate_linalg_extern_fn!(call_np_linalg_det, "np_linalg_det", 2);
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);

View File

@ -1,15 +1,17 @@
use inkwell::{
context::Context,
types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, IntValue, PointerValue},
};
use nac3parser::ast::{Expr, Stmt, StrRef};
use crate::{ use crate::{
codegen::{expr::*, stmt::*, CodeGenContext}, codegen::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext},
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type}, typecheck::typedef::{FunSignature, Type},
}; };
use inkwell::{
context::Context,
types::{BasicTypeEnum, IntType},
values::{BasicValueEnum, PointerValue},
};
use nac3parser::ast::{Expr, Stmt, StrRef};
pub trait CodeGenerator { pub trait CodeGenerator {
/// Return the module name for the code generator. /// Return the module name for the code generator.
@ -22,9 +24,9 @@ pub trait CodeGenerator {
/// - fun: Function signature and definition ID. /// - fun: Function signature and definition ID.
/// - params: Function parameters. Note that this does not include the object even if the /// - params: Function parameters. Note that this does not include the object even if the
/// function is a class method. /// function is a class method.
fn gen_call<'ctx, 'a>( fn gen_call<'ctx>(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
@ -36,12 +38,12 @@ pub trait CodeGenerator {
} }
/// Generate object constructor and returns the constructed object. /// Generate object constructor and returns the constructed object.
/// - signature: Function signature of the contructor. /// - signature: Function signature of the constructor.
/// - def: Class definition for the constructor class. /// - def: Class definition for the constructor class.
/// - params: Function parameters. /// - params: Function parameters.
fn gen_constructor<'ctx, 'a>( fn gen_constructor<'ctx>(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
signature: &FunSignature, signature: &FunSignature,
def: &TopLevelDef, def: &TopLevelDef,
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
@ -57,22 +59,23 @@ pub trait CodeGenerator {
/// - fun: Function signature, definition ID and the substitution key. /// - fun: Function signature, definition ID and the substitution key.
/// - params: Function parameters. Note that this does not include the object even if the /// - params: Function parameters. Note that this does not include the object even if the
/// function is a class method. /// function is a class method.
///
/// Note that this function should check if the function is generated in another thread (due to /// Note that this function should check if the function is generated in another thread (due to
/// possible race condition), see the default implementation for an example. /// possible race condition), see the default implementation for an example.
fn gen_func_instance<'ctx, 'a>( fn gen_func_instance<'ctx>(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, &mut TopLevelDef, String), fun: (&FunSignature, &mut TopLevelDef, String),
id: usize, id: usize,
) -> Result<String, String> { ) -> Result<String, String> {
gen_func_instance(ctx, obj, fun, id) gen_func_instance(ctx, &obj, fun, id)
} }
/// Generate the code for an expression. /// Generate the code for an expression.
fn gen_expr<'ctx, 'a>( fn gen_expr<'ctx>(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
expr: &Expr<Option<Type>>, expr: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> ) -> Result<Option<ValueEnum<'ctx>>, String>
where where
@ -83,44 +86,92 @@ pub trait CodeGenerator {
/// Allocate memory for a variable and return a pointer pointing to it. /// Allocate memory for a variable and return a pointer pointing to it.
/// The default implementation places the allocations at the start of the function. /// The default implementation places the allocations at the start of the function.
fn gen_var_alloc<'ctx, 'a>( fn gen_var_alloc<'ctx>(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>, ty: BasicTypeEnum<'ctx>,
name: Option<&str>,
) -> Result<PointerValue<'ctx>, String> { ) -> Result<PointerValue<'ctx>, String> {
gen_var(ctx, ty) gen_var(ctx, ty, name)
}
/// Allocate memory for a variable and return a pointer pointing to it.
/// The default implementation places the allocations at the start of the function.
fn gen_array_var_alloc<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> Result<ArraySliceValue<'ctx>, String> {
gen_array_var(ctx, ty, size, name)
} }
/// Return a pointer pointing to the target of the expression. /// Return a pointer pointing to the target of the expression.
fn gen_store_target<'ctx, 'a>( fn gen_store_target<'ctx>(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
pattern: &Expr<Option<Type>>, pattern: &Expr<Option<Type>>,
) -> Result<PointerValue<'ctx>, String> name: Option<&str>,
) -> Result<Option<PointerValue<'ctx>>, String>
where where
Self: Sized, Self: Sized,
{ {
gen_store_target(self, ctx, pattern) gen_store_target(self, ctx, pattern, name)
} }
/// Generate code for an assignment expression. /// Generate code for an assignment expression.
fn gen_assign<'ctx, 'a>( fn gen_assign<'ctx>(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>, target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>, value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String> ) -> Result<(), String>
where where
Self: Sized, Self: Sized,
{ {
gen_assign(self, ctx, target, value) gen_assign(self, ctx, target, value, value_ty)
}
/// Generate code for an assignment expression where LHS is a `"target_list"`.
///
/// See <https://docs.python.org/3/reference/simple_stmts.html#assignment-statements>.
fn gen_assign_target_list<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
targets: &Vec<Expr<Option<Type>>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_assign_target_list(self, ctx, targets, value, value_ty)
}
/// Generate code for an item assignment.
///
/// i.e., `target[key] = value`
fn gen_setitem<'ctx>(
&mut self,
ctx: &mut CodeGenContext<'ctx, '_>,
target: &Expr<Option<Type>>,
key: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
value_ty: Type,
) -> Result<(), String>
where
Self: Sized,
{
gen_setitem(self, ctx, target, key, value, value_ty)
} }
/// Generate code for a while expression. /// Generate code for a while expression.
/// Return true if the while loop must early return /// Return true if the while loop must early return
fn gen_while<'ctx, 'a>( fn gen_while(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> ) -> Result<(), String>
where where
@ -129,11 +180,11 @@ pub trait CodeGenerator {
gen_while(self, ctx, stmt) gen_while(self, ctx, stmt)
} }
/// Generate code for a while expression. /// Generate code for a for expression.
/// Return true if the while loop must early return /// Return true if the for loop must early return
fn gen_for<'ctx, 'a>( fn gen_for(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> ) -> Result<(), String>
where where
@ -144,9 +195,9 @@ pub trait CodeGenerator {
/// Generate code for an if expression. /// Generate code for an if expression.
/// Return true if the statement must early return /// Return true if the statement must early return
fn gen_if<'ctx, 'a>( fn gen_if(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> ) -> Result<(), String>
where where
@ -155,9 +206,9 @@ pub trait CodeGenerator {
gen_if(self, ctx, stmt) gen_if(self, ctx, stmt)
} }
fn gen_with<'ctx, 'a>( fn gen_with(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> ) -> Result<(), String>
where where
@ -167,10 +218,11 @@ pub trait CodeGenerator {
} }
/// Generate code for a statement /// Generate code for a statement
///
/// Return true if the statement must early return /// Return true if the statement must early return
fn gen_stmt<'ctx, 'a>( fn gen_stmt(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> ) -> Result<(), String>
where where
@ -178,6 +230,36 @@ pub trait CodeGenerator {
{ {
gen_stmt(self, ctx, stmt) gen_stmt(self, ctx, stmt)
} }
/// Generates code for a block statement.
fn gen_block<'a, I: Iterator<Item = &'a Stmt<Option<Type>>>>(
&mut self,
ctx: &mut CodeGenContext<'_, '_>,
stmts: I,
) -> Result<(), String>
where
Self: Sized,
{
gen_block(self, ctx, stmts)
}
/// See [`bool_to_i1`].
fn bool_to_i1<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> {
bool_to_i1(&ctx.builder, bool_value)
}
/// See [`bool_to_i8`].
fn bool_to_i8<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> {
bool_to_i8(&ctx.builder, ctx.ctx, bool_value)
}
} }
pub struct DefaultCodeGenerator { pub struct DefaultCodeGenerator {
@ -186,17 +268,20 @@ pub struct DefaultCodeGenerator {
} }
impl DefaultCodeGenerator { impl DefaultCodeGenerator {
#[must_use]
pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator { pub fn new(name: String, size_t: u32) -> DefaultCodeGenerator {
assert!(size_t == 32 || size_t == 64); assert!(matches!(size_t, 32 | 64));
DefaultCodeGenerator { name, size_t } DefaultCodeGenerator { name, size_t }
} }
} }
impl CodeGenerator for DefaultCodeGenerator { impl CodeGenerator for DefaultCodeGenerator {
/// Returns the name for this [`CodeGenerator`].
fn get_name(&self) -> &str { fn get_name(&self) -> &str {
&self.name &self.name
} }
/// Returns an LLVM integer type representing `size_t`.
fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> { fn get_size_type<'ctx>(&self, ctx: &'ctx Context) -> IntType<'ctx> {
// it should be unsigned, but we don't really need unsigned and this could save us from // it should be unsigned, but we don't really need unsigned and this could save us from
// having to do a bit cast... // having to do a bit cast...

View File

@ -1,140 +0,0 @@
typedef _ExtInt(8) int8_t;
typedef unsigned _ExtInt(8) uint8_t;
typedef _ExtInt(32) int32_t;
typedef unsigned _ExtInt(32) uint32_t;
typedef _ExtInt(64) int64_t;
typedef unsigned _ExtInt(64) uint64_t;
# define MAX(a, b) (a > b ? a : b)
# define MIN(a, b) (a > b ? b : a)
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
#define DEF_INT_EXP(T) T __nac3_int_exp_##T( \
T base, \
T exp \
) { \
T res = (T)1; \
/* repeated squaring method */ \
do { \
if (exp & 1) res *= base; /* for n odd */ \
exp >>= 1; \
base *= base; \
} while (exp); \
return res; \
} \
DEF_INT_EXP(int32_t)
DEF_INT_EXP(int64_t)
DEF_INT_EXP(uint32_t)
DEF_INT_EXP(uint64_t)
int32_t __nac3_slice_index_bound(int32_t i, const int32_t len) {
if (i < 0) {
i = len + i;
}
if (i < 0) {
return 0;
} else if (i > len) {
return len;
}
return i;
}
int32_t __nac3_range_slice_len(const int32_t start, const int32_t end, const int32_t step) {
int32_t diff = end - start;
if (diff > 0 && step > 0) {
return ((diff - 1) / step) + 1;
} else if (diff < 0 && step < 0) {
return ((diff + 1) / step) + 1;
} else {
return 0;
}
}
// Handle list assignment and dropping part of the list when
// both dest_step and src_step are +1.
// - All the index must *not* be out-of-bound or negative,
// - The end index is *inclusive*,
// - The length of src and dest slice size should already
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
int32_t __nac3_list_slice_assign_var_size(
int32_t dest_start,
int32_t dest_end,
int32_t dest_step,
uint8_t *dest_arr,
int32_t dest_arr_len,
int32_t src_start,
int32_t src_end,
int32_t src_step,
uint8_t *src_arr,
int32_t src_arr_len,
const int32_t size
) {
/* if dest_arr_len == 0, do nothing since we do not support extending list */
if (dest_arr_len == 0) return dest_arr_len;
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
if (src_step == dest_step && dest_step == 1) {
const int32_t src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
const int32_t dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
if (src_len > 0) {
__builtin_memmove(
dest_arr + dest_start * size,
src_arr + src_start * size,
src_len * size
);
}
if (dest_len > 0) {
/* dropping */
__builtin_memmove(
dest_arr + (dest_start + src_len) * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
}
/* shrink size */
return dest_arr_len - (dest_len - src_len);
}
/* if two range overlaps, need alloca */
uint8_t need_alloca =
(dest_arr == src_arr)
&& !(
MAX(dest_start, dest_end) < MIN(src_start, src_end)
|| MAX(src_start, src_end) < MIN(dest_start, dest_end)
);
if (need_alloca) {
uint8_t *tmp = __builtin_alloca(src_arr_len * size);
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
src_arr = tmp;
}
int32_t src_ind = src_start;
int32_t dest_ind = dest_start;
for (;
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
src_ind += src_step, dest_ind += dest_step
) {
/* for constant optimization */
if (size == 1) {
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
} else if (size == 4) {
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
} else if (size == 8) {
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
} else {
/* memcpy for var size, cannot overlap after previous alloca */
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
}
}
/* only dest_step == 1 can we shrink the dest list. */
/* size should be ensured prior to calling this function */
if (dest_step == 1 && dest_end >= dest_start) {
__builtin_memmove(
dest_arr + dest_ind * size,
dest_arr + (dest_end + 1) * size,
(dest_arr_len - dest_end - 1) * size
);
return dest_arr_len - (dest_end - dest_ind) - 1;
}
return dest_arr_len;
}

View File

@ -1,18 +1,30 @@
use crate::typecheck::typedef::Type;
use super::{CodeGenContext, CodeGenerator};
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
context::Context, context::Context,
memory_buffer::MemoryBuffer, memory_buffer::MemoryBuffer,
module::Module, module::Module,
types::{BasicTypeEnum, IntType}, types::{BasicTypeEnum, IntType},
values::{IntValue, PointerValue}, values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
use itertools::Either;
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
pub fn load_irrt(ctx: &Context) -> Module { use super::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
llvm_intrinsics,
macros::codegen_unreachable,
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
#[must_use]
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
let bitcode_buf = MemoryBuffer::create_from_memory_range( let bitcode_buf = MemoryBuffer::create_from_memory_range(
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")), include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
"irrt_bitcode_buffer", "irrt_bitcode_buffer",
@ -28,13 +40,33 @@ pub fn load_irrt(ctx: &Context) -> Module {
let function = irrt_mod.get_function(symbol).unwrap(); let function = irrt_mod.get_function(symbol).unwrap();
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0)); function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
} }
// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
let exn_id_type = ctx.i32_type();
let errors = &[
("EXN_INDEX_ERROR", "0:IndexError"),
("EXN_VALUE_ERROR", "0:ValueError"),
("EXN_ASSERTION_ERROR", "0:AssertionError"),
("EXN_TYPE_ERROR", "0:TypeError"),
];
for (irrt_name, symbol_name) in errors {
let exn_id = symbol_resolver.get_string_id(symbol_name);
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
let global = irrt_mod.get_global(irrt_name).unwrap_or_else(|| {
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
});
global.set_initializer(&exn_id);
}
irrt_mod irrt_mod
} }
// repeated squaring method adapted from GNU Scientific Library: // repeated squaring method adapted from GNU Scientific Library:
// https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c // https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
pub fn integer_power<'ctx, 'a>( pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, 'a>, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
base: IntValue<'ctx>, base: IntValue<'ctx>,
exp: IntValue<'ctx>, exp: IntValue<'ctx>,
signed: bool, signed: bool,
@ -44,23 +76,42 @@ pub fn integer_power<'ctx, 'a>(
(64, 64, true) => "__nac3_int_exp_int64_t", (64, 64, true) => "__nac3_int_exp_int64_t",
(32, 32, false) => "__nac3_int_exp_uint32_t", (32, 32, false) => "__nac3_int_exp_uint32_t",
(64, 64, false) => "__nac3_int_exp_uint64_t", (64, 64, false) => "__nac3_int_exp_uint64_t",
_ => unreachable!(), _ => codegen_unreachable!(ctx),
}; };
let base_type = base.get_type(); let base_type = base.get_type();
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| { let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false); let fn_type = base_type.fn_type(&[base_type.into(), base_type.into()], false);
ctx.module.add_function(symbol, fn_type, None) ctx.module.add_function(symbol, fn_type, None)
}); });
// TODO: throw exception when exp < 0 // throw exception when exp < 0
let ge_zero = ctx
.builder
.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
)
.unwrap();
ctx.make_assert(
generator,
ge_zero,
"0:ValueError",
"integer power must be positive or zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder ctx.builder
.build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow") .build_call(pow_fun, &[base.into(), exp.into()], "call_int_pow")
.try_as_basic_value() .map(CallSiteValue::try_as_basic_value)
.unwrap_left() .map(|v| v.map_left(BasicValueEnum::into_int_value))
.into_int_value() .map(Either::unwrap_left)
.unwrap()
} }
pub fn calculate_len_for_slice_range<'ctx, 'a>( pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, 'a>, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
start: IntValue<'ctx>, start: IntValue<'ctx>,
end: IntValue<'ctx>, end: IntValue<'ctx>,
step: IntValue<'ctx>, step: IntValue<'ctx>,
@ -72,13 +123,25 @@ pub fn calculate_len_for_slice_range<'ctx, 'a>(
ctx.module.add_function(SYMBOL, fn_t, None) ctx.module.add_function(SYMBOL, fn_t, None)
}); });
// TODO: assert step != 0, throw exception if not // assert step != 0, throw exception if not
let not_zero = ctx
.builder
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
.unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"step must not be zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder ctx.builder
.build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len") .build_call(len_func, &[start.into(), end.into(), step.into()], "calc_len")
.try_as_basic_value() .map(CallSiteValue::try_as_basic_value)
.left() .map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap() .unwrap()
.into_int_value()
} }
/// NOTE: the output value of the end index of this function should be compared ***inclusively***, /// NOTE: the output value of the end index of this function should be compared ***inclusively***,
@ -121,95 +184,140 @@ pub fn calculate_len_for_slice_range<'ctx, 'a>(
/// ,step /// ,step
/// ) /// )
/// ``` /// ```
pub fn handle_slice_indices<'a, 'ctx, G: CodeGenerator>( pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
start: &Option<Box<Expr<Option<Type>>>>, start: &Option<Box<Expr<Option<Type>>>>,
end: &Option<Box<Expr<Option<Type>>>>, end: &Option<Box<Expr<Option<Type>>>>,
step: &Option<Box<Expr<Option<Type>>>>, step: &Option<Box<Expr<Option<Type>>>>,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G, generator: &mut G,
list: PointerValue<'ctx>, length: IntValue<'ctx>,
) -> Result<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), String> { ) -> Result<Option<(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)>, String> {
// TODO: throw exception when step is 0
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero(); let zero = int32.const_zero();
let one = int32.const_int(1, false); let one = int32.const_int(1, false);
let length = ctx.build_gep_and_load(list, &[zero, one]).into_int_value(); let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32").unwrap();
let length = ctx.builder.build_int_truncate_or_bit_cast(length, int32, "leni32"); Ok(Some(match (start, end, step) {
Ok(match (start, end, step) {
(s, e, None) => ( (s, e, None) => (
s.as_ref().map_or_else( if let Some(s) = s.as_ref() {
|| Ok(int32.const_zero()), match handle_slice_index_bound(s, ctx, generator, length)? {
|s| handle_slice_index_bound(s, ctx, generator, length), Some(v) => v,
)?, None => return Ok(None),
}
} else {
int32.const_zero()
},
{ {
let e = e.as_ref().map_or_else( let e = if let Some(s) = e.as_ref() {
|| Ok(length), match handle_slice_index_bound(s, ctx, generator, length)? {
|e| handle_slice_index_bound(e, ctx, generator, length), Some(v) => v,
)?; None => return Ok(None),
ctx.builder.build_int_sub(e, one, "final_end") }
} else {
length
};
ctx.builder.build_int_sub(e, one, "final_end").unwrap()
}, },
one, one,
), ),
(s, e, Some(step)) => { (s, e, Some(step)) => {
let step = generator let step = if let Some(v) = generator.gen_expr(ctx, step)? {
.gen_expr(ctx, step)? v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value()
.unwrap() } else {
.to_basic_value_enum(ctx, generator)? return Ok(None);
.into_int_value(); };
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1"); // assert step != 0, throw exception if not
let neg = ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg"); let not_zero = ctx
.builder
.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
)
.unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"slice step cannot be zero",
[None, None, None],
ctx.current_loc,
);
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap();
let neg = ctx
.builder
.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg")
.unwrap();
( (
match s { match s {
Some(s) => { Some(s) => {
let s = handle_slice_index_bound(s, ctx, generator, length)?; let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else {
return Ok(None);
};
ctx.builder ctx.builder
.build_select( .build_select(
ctx.builder.build_and( ctx.builder
ctx.builder.build_int_compare( .build_and(
IntPredicate::EQ, ctx.builder
s, .build_int_compare(
length, IntPredicate::EQ,
"s_eq_len", s,
), length,
neg, "s_eq_len",
"should_minus_one", )
), .unwrap(),
ctx.builder.build_int_sub(s, one, "s_min"), neg,
"should_minus_one",
)
.unwrap(),
ctx.builder.build_int_sub(s, one, "s_min").unwrap(),
s, s,
"final_start", "final_start",
) )
.into_int_value() .map(BasicValueEnum::into_int_value)
.unwrap()
} }
None => ctx.builder.build_select(neg, len_id, zero, "stt").into_int_value(), None => ctx
.builder
.build_select(neg, len_id, zero, "stt")
.map(BasicValueEnum::into_int_value)
.unwrap(),
}, },
match e { match e {
Some(e) => { Some(e) => {
let e = handle_slice_index_bound(e, ctx, generator, length)?; let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else {
return Ok(None);
};
ctx.builder ctx.builder
.build_select( .build_select(
neg, neg,
ctx.builder.build_int_add(e, one, "end_add_one"), ctx.builder.build_int_add(e, one, "end_add_one").unwrap(),
ctx.builder.build_int_sub(e, one, "end_sub_one"), ctx.builder.build_int_sub(e, one, "end_sub_one").unwrap(),
"final_end", "final_end",
) )
.into_int_value() .map(BasicValueEnum::into_int_value)
.unwrap()
} }
None => ctx.builder.build_select(neg, zero, len_id, "end").into_int_value(), None => ctx
.builder
.build_select(neg, zero, len_id, "end")
.map(BasicValueEnum::into_int_value)
.unwrap(),
}, },
step, step,
) )
} }
}) }))
} }
/// this function allows index out of range, since python /// this function allows index out of range, since python
/// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`). /// allows index out of range in slice (`a = [1,2,3]; a[1:10] == [2,3]`).
pub fn handle_slice_index_bound<'a, 'ctx, G: CodeGenerator>( pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
i: &Expr<Option<Type>>, i: &Expr<Option<Type>>,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut G, generator: &mut G,
length: IntValue<'ctx>, length: IntValue<'ctx>,
) -> Result<IntValue<'ctx>, String> { ) -> Result<Option<IntValue<'ctx>>, String> {
const SYMBOL: &str = "__nac3_slice_index_bound"; const SYMBOL: &str = "__nac3_slice_index_bound";
let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| { let func = ctx.module.get_function(SYMBOL).unwrap_or_else(|| {
let i32_t = ctx.ctx.i32_type(); let i32_t = ctx.ctx.i32_type();
@ -217,29 +325,35 @@ pub fn handle_slice_index_bound<'a, 'ctx, G: CodeGenerator>(
ctx.module.add_function(SYMBOL, fn_t, None) ctx.module.add_function(SYMBOL, fn_t, None)
}); });
let i = generator.gen_expr(ctx, i)?.unwrap().to_basic_value_enum(ctx, generator)?; let i = if let Some(v) = generator.gen_expr(ctx, i)? {
Ok(ctx v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
.builder } else {
.build_call(func, &[i.into(), length.into()], "bounded_ind") return Ok(None);
.try_as_basic_value() };
.left() Ok(Some(
.unwrap() ctx.builder
.into_int_value()) .build_call(func, &[i.into(), length.into()], "bounded_ind")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap(),
))
} }
/// This function handles 'end' **inclusively**. /// This function handles 'end' **inclusively**.
/// Order of tuples assign_idx and value_idx is ('start', 'end', 'step'). /// Order of tuples `assign_idx` and `value_idx` is ('start', 'end', 'step').
/// Negative index should be handled before entering this function /// Negative index should be handled before entering this function
pub fn list_slice_assignment<'ctx, 'a>( pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, 'a>, generator: &mut G,
size_ty: IntType<'ctx>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: BasicTypeEnum<'ctx>, ty: BasicTypeEnum<'ctx>,
dest_arr: PointerValue<'ctx>, dest_arr: ListValue<'ctx>,
dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), dest_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
src_arr: PointerValue<'ctx>, src_arr: ListValue<'ctx>,
src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>), src_idx: (IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>),
) { ) {
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::Generic); let size_ty = generator.get_size_type(ctx.ctx);
let int8_ptr = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr); let (fun_symbol, elem_ptr_type) = ("__nac3_list_slice_assign_var_size", int8_ptr);
let slice_assign_fun = { let slice_assign_fun = {
@ -264,26 +378,72 @@ pub fn list_slice_assignment<'ctx, 'a>(
let zero = int32.const_zero(); let zero = int32.const_zero();
let one = int32.const_int(1, false); let one = int32.const_int(1, false);
let dest_arr_ptr = ctx.build_gep_and_load(dest_arr, &[zero, zero]); let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
let dest_arr_ptr = ctx.builder.build_pointer_cast( let dest_arr_ptr =
dest_arr_ptr.into_pointer_value(), ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
elem_ptr_type, let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
"dest_arr_ptr_cast", let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
); let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
let dest_len = ctx.build_gep_and_load(dest_arr, &[zero, one]).into_int_value(); let src_arr_ptr =
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32"); ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
let src_arr_ptr = ctx.build_gep_and_load(src_arr, &[zero, zero]); let src_len = src_arr.load_size(ctx, Some("src.len"));
let src_arr_ptr = ctx.builder.build_pointer_cast( let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
src_arr_ptr.into_pointer_value(),
elem_ptr_type,
"src_arr_ptr_cast",
);
let src_len = ctx.build_gep_and_load(src_arr, &[zero, one]).into_int_value();
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32");
// index in bound and positive should be done // index in bound and positive should be done
// TODO: assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
// throw exception if not satisfied // throw exception if not satisfied
let src_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let dest_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let src_slice_len =
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len =
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
let src_eq_dest = ctx
.builder
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
.unwrap();
let src_slt_dest = ctx
.builder
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
.unwrap();
let dest_step_eq_one = ctx
.builder
.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
)
.unwrap();
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
ctx.make_assert(
generator,
cond,
"0:ValueError",
"attempt to assign sequence of size {0} to slice of size {1} with step size {2}",
[Some(src_slice_len), Some(dest_slice_len), Some(dest_idx.2)],
ctx.current_loc,
);
let new_len = { let new_len = {
let args = vec![ let args = vec![
dest_idx.0.into(), // dest start idx dest_idx.0.into(), // dest start idx
@ -302,29 +462,491 @@ pub fn list_slice_assignment<'ctx, 'a>(
BasicTypeEnum::IntType(t) => t.size_of(), BasicTypeEnum::IntType(t) => t.size_of(),
BasicTypeEnum::PointerType(t) => t.size_of(), BasicTypeEnum::PointerType(t) => t.size_of(),
BasicTypeEnum::StructType(t) => t.size_of().unwrap(), BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
_ => unreachable!(), _ => codegen_unreachable!(ctx),
}; };
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size") ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
} }
.into(), .into(),
]; ];
ctx.builder ctx.builder
.build_call(slice_assign_fun, args.as_slice(), "slice_assign") .build_call(slice_assign_fun, args.as_slice(), "slice_assign")
.try_as_basic_value() .map(CallSiteValue::try_as_basic_value)
.unwrap_left() .map(|v| v.map_left(BasicValueEnum::into_int_value))
.into_int_value() .map(Either::unwrap_left)
.unwrap()
}; };
// update length // update length
let need_update = let need_update =
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update"); ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let update_bb = ctx.ctx.append_basic_block(current, "update"); let update_bb = ctx.ctx.append_basic_block(current, "update");
let cont_bb = ctx.ctx.append_basic_block(current, "cont"); let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb); ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
ctx.builder.position_at_end(update_bb); ctx.builder.position_at_end(update_bb);
let dest_len_ptr = unsafe { ctx.builder.build_gep(dest_arr, &[zero, one], "dest_len_ptr") }; let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len"); dest_arr.store_size(ctx, generator, new_len);
ctx.builder.build_store(dest_len_ptr, new_len); ctx.builder.build_unconditional_branch(cont_bb).unwrap();
ctx.builder.build_unconditional_branch(cont_bb);
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
} }
/// Generates a call to `isinf` in IR. Returns an `i1` representing the result.
pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isinf").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isinf", fn_type, None)
});
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isinf")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `isnan` in IR. Returns an `i1` representing the result.
pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> IntValue<'ctx> {
let intrinsic_fn = ctx.module.get_function("__nac3_isnan").unwrap_or_else(|| {
let fn_type = ctx.ctx.i32_type().fn_type(&[ctx.ctx.f64_type().into()], false);
ctx.module.add_function("__nac3_isnan", fn_type, None)
});
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isnan")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
generator.bool_to_i1(ctx, ret)
}
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gamma", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gamma")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_gammaln", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "gammaln")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("__nac3_j0", fn_type, None)
});
ctx.builder
.build_call(intrinsic_fn, &[v.into()], "j0")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension
/// respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
let ndarray_calc_size_fn =
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
ctx.builder
.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Indices,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let index = ctx
.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
index
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let rhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
let lhs_eq_rhs = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap();
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}

View File

@ -0,0 +1,345 @@
use inkwell::{
context::Context,
intrinsics::Intrinsic,
types::{AnyTypeEnum::IntType, FloatType},
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue},
AddressSpace,
};
use itertools::Either;
use crate::codegen::CodeGenContext;
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
/// functions.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
// Standard LLVM floating-point types
if ft == ctx.f16_type() {
return "f16";
}
if ft == ctx.f32_type() {
return "f32";
}
if ft == ctx.f64_type() {
return "f64";
}
if ft == ctx.f128_type() {
return "f128";
}
// Non-standard floating-point types
if ft == ctx.x86_f80_type() {
return "f80";
}
if ft == ctx.ppc_f128_type() {
return "ppcf128";
}
unreachable!()
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_start<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_start";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.va_start`](https://llvm.org/docs/LangRef.html#llvm-va-start-intrinsic)
/// intrinsic.
pub fn call_va_end<'ctx>(ctx: &CodeGenContext<'ctx, '_>, arglist: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.va_end";
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[arglist.into()], "").unwrap();
}
/// Invokes the [`llvm.stacksave`](https://llvm.org/docs/LangRef.html#llvm-stacksave-intrinsic)
/// intrinsic.
pub fn call_stacksave<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
name: Option<&str>,
) -> PointerValue<'ctx> {
const FN_NAME: &str = "llvm.stacksave";
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_pointer_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Invokes the
/// [`llvm.stackrestore`](https://llvm.org/docs/LangRef.html#llvm-stackrestore-intrinsic) intrinsic.
///
/// - `ptr`: The pointer storing the address to restore the stack to.
pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>) {
const FN_NAME: &str = "llvm.stackrestore";
/*
SEE https://github.com/TheDan64/inkwell/issues/496
We want `llvm.stackrestore`, but the following would generate `llvm.stackrestore.p0i8`.
```ignore
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_p0i8.into()]))
.unwrap();
```
Temp workaround by manually declaring the intrinsic with the correct function name instead.
*/
let intrinsic_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
let llvm_void = ctx.ctx.void_type();
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let fn_type = llvm_void.fn_type(&[llvm_p0i8.into()], false);
ctx.module.add_function(FN_NAME, fn_type, None)
});
ctx.builder.build_call(intrinsic_fn, &[ptr.into()], "").unwrap();
}
/// Invokes the [`llvm.memcpy`](https://llvm.org/docs/LangRef.html#llvm-memcpy-intrinsic) intrinsic.
///
/// * `dest` - The pointer to the destination. Must be a pointer to an integer type.
/// * `src` - The pointer to the source. Must be a pointer to an integer type.
/// * `len` - The number of bytes to copy.
/// * `is_volatile` - Whether the `memcpy` operation should be `volatile`.
pub fn call_memcpy<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
const FN_NAME: &str = "llvm.memcpy";
debug_assert!(dest.get_type().get_element_type().is_int_type());
debug_assert!(src.get_type().get_element_type().is_int_type());
debug_assert_eq!(
dest.get_type().get_element_type().into_int_type().get_bit_width(),
src.get_type().get_element_type().into_int_type().get_bit_width(),
);
debug_assert!(matches!(len.get_type().get_bit_width(), 32 | 64));
debug_assert_eq!(is_volatile.get_type().get_bit_width(), 1);
let llvm_dest_t = dest.get_type();
let llvm_src_t = src.get_type();
let llvm_len_t = len.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| {
intrinsic.get_declaration(
&ctx.module,
&[llvm_dest_t.into(), llvm_src_t.into(), llvm_len_t.into()],
)
})
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[dest.into(), src.into(), len.into(), is_volatile.into()], "")
.unwrap();
}
/// Invokes the `llvm.memcpy` intrinsic.
///
/// Unlike [`call_memcpy`], this function accepts any type of pointer value. If `dest` or `src` is
/// not a pointer to an integer, the pointer(s) will be cast to `i8*` before invoking `memcpy`.
pub fn call_memcpy_generic<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
dest: PointerValue<'ctx>,
src: PointerValue<'ctx>,
len: IntValue<'ctx>,
is_volatile: IntValue<'ctx>,
) {
let llvm_i8 = ctx.ctx.i8_type();
let llvm_p0i8 = llvm_i8.ptr_type(AddressSpace::default());
let dest_elem_t = dest.get_type().get_element_type();
let src_elem_t = src.get_type().get_element_type();
let dest = if matches!(dest_elem_t, IntType(t) if t.get_bit_width() == 8) {
dest
} else {
ctx.builder
.build_bit_cast(dest, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
let src = if matches!(src_elem_t, IntType(t) if t.get_bit_width() == 8) {
src
} else {
ctx.builder
.build_bit_cast(src, llvm_p0i8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap()
};
call_memcpy(ctx, dest, src, len, is_volatile);
}
/// Macro to find and generate build call for llvm intrinsic (body of llvm intrinsic function)
///
/// Arguments:
/// * `$ctx:ident`: Reference to the current Code Generation Context
/// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>)
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
/// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type).
/// Use `BasicValueEnum::into_int_value` for Integer return type and
/// `BasicValueEnum::into_float_value` for Float return type
/// * `$llvm_ty:ident`: Type of first operand
/// * `,($val:ident)*`: Comma separated list of operands
macro_rules! generate_llvm_intrinsic_fn_body {
($ctx:ident, $name:ident, $llvm_name:literal, $map_fn:expr, $llvm_ty:ident $(,$val:ident)*) => {{
const FN_NAME: &str = concat!("llvm.", $llvm_name);
let intrinsic_fn = Intrinsic::find(FN_NAME).and_then(|intrinsic| intrinsic.get_declaration(&$ctx.module, &[$llvm_ty.into()])).unwrap();
$ctx.builder.build_call(intrinsic_fn, &[$($val.into()),*], $name.unwrap_or_default()).map(CallSiteValue::try_as_basic_value).map(|v| v.map_left($map_fn)).map(Either::unwrap_left).unwrap()
}};
}
/// Macro to generate the llvm intrinsic function using [`generate_llvm_intrinsic_fn_body`].
///
/// Arguments:
/// * `float/int`: Indicates the return and argument type of the function
/// * `$fn_name:ident`: The identifier of the rust function to be generated
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function.
/// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil"
/// * `$val:ident`: The operand for unary operations
/// * `$val1:ident`, `$val2:ident`: The operands for binary operations
macro_rules! generate_llvm_intrinsic_fn {
("float", $fn_name:ident, $llvm_name:literal, $val:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
let llvm_ty = $val.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_float_value, llvm_ty, $val)
}
};
("float", $fn_name:ident, $llvm_name:literal, $val1:ident, $val2:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val1: FloatValue<'ctx>,
$val2: FloatValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
debug_assert_eq!($val1.get_type(), $val2.get_type());
let llvm_ty = $val1.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_float_value, llvm_ty, $val1, $val2)
}
};
("int", $fn_name:ident, $llvm_name:literal, $val1:ident, $val2:ident) => {
#[doc = concat!("Invokes the [`", stringify!($llvm_name), "`](https://llvm.org/docs/LangRef.html#llvm-", stringify!($llvm_name), "-intrinsic) intrinsic." )]
pub fn $fn_name<'ctx> (
ctx: &CodeGenContext<'ctx, '_>,
$val1: IntValue<'ctx>,
$val2: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!($val1.get_type().get_bit_width(), $val2.get_type().get_bit_width());
let llvm_ty = $val1.get_type();
generate_llvm_intrinsic_fn_body!(ctx, name, $llvm_name, BasicValueEnum::into_int_value, llvm_ty, $val1, $val2)
}
};
}
/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic.
///
/// * `src` - The value for which the absolute value is to be returned.
/// * `is_int_min_poison` - Whether `poison` is to be returned if `src` is `INT_MIN`.
pub fn call_int_abs<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
src: IntValue<'ctx>,
is_int_min_poison: IntValue<'ctx>,
name: Option<&str>,
) -> IntValue<'ctx> {
debug_assert_eq!(is_int_min_poison.get_type().get_bit_width(), 1);
debug_assert!(is_int_min_poison.is_const());
let src_type = src.get_type();
generate_llvm_intrinsic_fn_body!(
ctx,
name,
"abs",
BasicValueEnum::into_int_value,
src_type,
src,
is_int_min_poison
)
}
generate_llvm_intrinsic_fn!("int", call_int_smax, "smax", a, b);
generate_llvm_intrinsic_fn!("int", call_int_smin, "smin", a, b);
generate_llvm_intrinsic_fn!("int", call_int_umax, "umax", a, b);
generate_llvm_intrinsic_fn!("int", call_int_umin, "umin", a, b);
generate_llvm_intrinsic_fn!("int", call_expect, "expect", val, expected_val);
generate_llvm_intrinsic_fn!("float", call_float_sqrt, "sqrt", val);
generate_llvm_intrinsic_fn!("float", call_float_sin, "sin", val);
generate_llvm_intrinsic_fn!("float", call_float_cos, "cos", val);
generate_llvm_intrinsic_fn!("float", call_float_pow, "pow", val, power);
generate_llvm_intrinsic_fn!("float", call_float_exp, "exp", val);
generate_llvm_intrinsic_fn!("float", call_float_exp2, "exp2", val);
generate_llvm_intrinsic_fn!("float", call_float_log, "log", val);
generate_llvm_intrinsic_fn!("float", call_float_log10, "log10", val);
generate_llvm_intrinsic_fn!("float", call_float_log2, "log2", val);
generate_llvm_intrinsic_fn!("float", call_float_fabs, "fabs", src);
generate_llvm_intrinsic_fn!("float", call_float_minnum, "minnum", val, power);
generate_llvm_intrinsic_fn!("float", call_float_maxnum, "maxnum", val, power);
generate_llvm_intrinsic_fn!("float", call_float_copysign, "copysign", mag, sgn);
generate_llvm_intrinsic_fn!("float", call_float_floor, "floor", val);
generate_llvm_intrinsic_fn!("float", call_float_ceil, "ceil", val);
generate_llvm_intrinsic_fn!("float", call_float_round, "round", val);
generate_llvm_intrinsic_fn!("float", call_float_rint, "rint", val);
/// Invokes the [`llvm.powi`](https://llvm.org/docs/LangRef.html#llvm-powi-intrinsic) intrinsic.
pub fn call_float_powi<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
val: FloatValue<'ctx>,
power: IntValue<'ctx>,
name: Option<&str>,
) -> FloatValue<'ctx> {
const FN_NAME: &str = "llvm.powi";
let llvm_val_t = val.get_type();
let llvm_power_t = power.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| {
intrinsic.get_declaration(&ctx.module, &[llvm_val_t.into(), llvm_power_t.into()])
})
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[val.into(), power.into()], name.unwrap_or_default())
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,25 +1,37 @@
use crate::{ use std::{
codegen::{ collections::{HashMap, HashSet},
concrete_type::ConcreteTypeStore, CodeGenContext, CodeGenTask, DefaultCodeGenerator, sync::Arc,
WithCall, WorkerRegistry,
},
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{
composer::TopLevelComposer, DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
},
typecheck::{
type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
},
}; };
use indexmap::IndexMap;
use indoc::indoc; use indoc::indoc;
use inkwell::{
targets::{InitializationConfig, Target},
OptimizationLevel,
};
use nac3parser::{ use nac3parser::{
ast::{fold::Fold, StrRef}, ast::{fold::Fold, FileName, StrRef},
parser::parse_program, parser::parse_program,
}; };
use parking_lot::RwLock; use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::Arc; use crate::{
codegen::{
classes::{ListType, NDArrayType, ProxyType, RangeType},
concrete_type::ConcreteTypeStore,
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask,
CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry,
},
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
},
typecheck::{
type_inferencer::{FunctionData, IdentifierInfo, Inferencer, PrimitiveStore},
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
};
struct Resolver { struct Resolver {
id_to_type: HashMap<StrRef, Type>, id_to_type: HashMap<StrRef, Type>,
@ -48,23 +60,24 @@ impl SymbolResolver for Resolver {
_: &PrimitiveStore, _: &PrimitiveStore,
str: StrRef, str: StrRef,
) -> Result<Type, String> { ) -> Result<Type, String> {
self.id_to_type.get(&str).cloned().ok_or_else(|| format!("cannot find symbol `{}`", str)) self.id_to_type.get(&str).copied().ok_or_else(|| format!("cannot find symbol `{str}`"))
} }
fn get_symbol_value<'ctx, 'a>( fn get_symbol_value<'ctx>(
&self, &self,
_: StrRef, _: StrRef,
_: &mut CodeGenContext<'ctx, 'a>, _: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
unimplemented!() unimplemented!()
} }
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, String> { fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.id_to_def self.id_to_def
.read() .read()
.get(&id) .get(&id)
.cloned() .copied()
.ok_or_else(|| format!("cannot find symbol `{}`", id)) .ok_or_else(|| HashSet::from([format!("cannot find symbol `{id}`")]))
} }
fn get_string_id(&self, _: &str) -> i32 { fn get_string_id(&self, _: &str) -> i32 {
@ -83,9 +96,9 @@ fn test_primitives() {
d = a if c == 1 else 0 d = a if c == 1 else 0
return d return d
"}; "};
let statements = parse_program(source, Default::default()).unwrap(); let statements = parse_program(source, FileName::default()).unwrap();
let composer: TopLevelComposer = Default::default(); let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
@ -94,17 +107,27 @@ fn test_primitives() {
let resolver = Arc::new(Resolver { let resolver = Arc::new(Resolver {
id_to_type: HashMap::new(), id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()), id_to_def: RwLock::new(HashMap::new()),
class_names: Default::default(), class_names: HashMap::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>; }) as Arc<dyn SymbolResolver + Send + Sync>;
let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()]; let threads = vec![DefaultCodeGenerator::new("test".into(), 32).into()];
let signature = FunSignature { let signature = FunSignature {
args: vec![ args: vec![
FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }, FuncArg {
FuncArg { name: "b".into(), ty: primitives.int32, default_value: None }, name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
FuncArg {
name: "b".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
},
], ],
ret: primitives.int32, ret: primitives.int32,
vars: HashMap::new(), vars: VarMap::new(),
}; };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
@ -119,12 +142,13 @@ fn test_primitives() {
}; };
let mut virtual_checks = Vec::new(); let mut virtual_checks = Vec::new();
let mut calls = HashMap::new(); let mut calls = HashMap::new();
let mut identifiers: HashSet<_> = ["a".into(), "b".into()].iter().cloned().collect(); let mut identifiers: HashMap<_, _> =
["a".into(), "b".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut inferencer = Inferencer { let mut inferencer = Inferencer {
top_level: &top_level, top_level: &top_level,
function_data: &mut function_data, function_data: &mut function_data,
unifier: &mut unifier, unifier: &mut unifier,
variable_mapping: Default::default(), variable_mapping: HashMap::default(),
primitives: &primitives, primitives: &primitives,
virtual_checks: &mut virtual_checks, virtual_checks: &mut virtual_checks,
calls: &mut calls, calls: &mut calls,
@ -148,7 +172,7 @@ fn test_primitives() {
}); });
let task = CodeGenTask { let task = CodeGenTask {
subst: Default::default(), subst: Vec::default(),
symbol_name: "testing".into(), symbol_name: "testing".into(),
body: Arc::new(statements), body: Arc::new(statements),
unifier_index: 0, unifier_index: 0,
@ -180,28 +204,48 @@ fn test_primitives() {
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
define i32 @testing(i32 %0, i32 %1) { ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0, i32 %1) local_unnamed_addr #0 !dbg !4 {
init: init:
%add = add i32 %0, %1 %add = add i32 %1, %0, !dbg !9
%cmp = icmp eq i32 %add, 1 %cmp = icmp eq i32 %add, 1, !dbg !10
br i1 %cmp, label %then, label %else %. = select i1 %cmp, i32 %0, i32 0, !dbg !11
ret i32 %., !dbg !12
then: ; preds = %init
br label %cont
else: ; preds = %init
br label %cont
cont: ; preds = %else, %then
%if_exp_result.0 = phi i32 [ %0, %then ], [ 0, %else ]
ret i32 %if_exp_result.0
} }
"}
attributes #0 = { mustprogress nofree norecurse nosync nounwind readnone willreturn }
!llvm.module.flags = !{!0, !1}
!llvm.dbg.cu = !{!2}
!0 = !{i32 2, !\"Debug Info Version\", i32 3}
!1 = !{i32 2, !\"Dwarf Version\", i32 4}
!2 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
!3 = !DIFile(filename: \"unknown\", directory: \"\")
!4 = distinct !DISubprogram(name: \"testing\", linkageName: \"testing\", scope: null, file: !3, line: 1, type: !5, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !8)
!5 = !DISubroutineType(flags: DIFlagPublic, types: !6)
!6 = !{!7}
!7 = !DIBasicType(name: \"_\", flags: DIFlagPublic)
!8 = !{}
!9 = !DILocation(line: 1, column: 9, scope: !4)
!10 = !DILocation(line: 2, column: 15, scope: !4)
!11 = !DILocation(line: 0, scope: !4)
!12 = !DILocation(line: 3, column: 8, scope: !4)
"}
.trim(); .trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
}))); })));
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, f);
Target::initialize_all(&InitializationConfig::default());
let llvm_options = CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(),
};
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
} }
@ -212,23 +256,28 @@ fn test_simple_call() {
a = foo(a) a = foo(a)
return a * 2 return a * 2
"}; "};
let statements_1 = parse_program(source_1, Default::default()).unwrap(); let statements_1 = parse_program(source_1, FileName::default()).unwrap();
let source_2 = indoc! { " let source_2 = indoc! { "
return a + 1 return a + 1
"}; "};
let statements_2 = parse_program(source_2, Default::default()).unwrap(); let statements_2 = parse_program(source_2, FileName::default()).unwrap();
let composer: TopLevelComposer = Default::default(); let composer = TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 32).0;
let mut unifier = composer.unifier.clone(); let mut unifier = composer.unifier.clone();
let primitives = composer.primitives_ty; let primitives = composer.primitives_ty;
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
unifier.top_level = Some(top_level.clone()); unifier.top_level = Some(top_level.clone());
let signature = FunSignature { let signature = FunSignature {
args: vec![FuncArg { name: "a".into(), ty: primitives.int32, default_value: None }], args: vec![FuncArg {
name: "a".into(),
ty: primitives.int32,
default_value: None,
is_vararg: false,
}],
ret: primitives.int32, ret: primitives.int32,
vars: HashMap::new(), vars: VarMap::new(),
}; };
let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone())); let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone()));
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
@ -252,7 +301,7 @@ fn test_simple_call() {
let resolver = Resolver { let resolver = Resolver {
id_to_type: HashMap::new(), id_to_type: HashMap::new(),
id_to_def: RwLock::new(HashMap::new()), id_to_def: RwLock::new(HashMap::new()),
class_names: Default::default(), class_names: HashMap::default(),
}; };
resolver.add_id_def("foo".into(), DefinitionId(foo_id)); resolver.add_id_def("foo".into(), DefinitionId(foo_id));
let resolver = Arc::new(resolver) as Arc<dyn SymbolResolver + Send + Sync>; let resolver = Arc::new(resolver) as Arc<dyn SymbolResolver + Send + Sync>;
@ -273,12 +322,13 @@ fn test_simple_call() {
}; };
let mut virtual_checks = Vec::new(); let mut virtual_checks = Vec::new();
let mut calls = HashMap::new(); let mut calls = HashMap::new();
let mut identifiers: HashSet<_> = ["a".into(), "foo".into()].iter().cloned().collect(); let mut identifiers: HashMap<_, _> =
["a".into(), "foo".into()].map(|id| (id, IdentifierInfo::default())).into();
let mut inferencer = Inferencer { let mut inferencer = Inferencer {
top_level: &top_level, top_level: &top_level,
function_data: &mut function_data, function_data: &mut function_data,
unifier: &mut unifier, unifier: &mut unifier,
variable_mapping: Default::default(), variable_mapping: HashMap::default(),
primitives: &primitives, primitives: &primitives,
virtual_checks: &mut virtual_checks, virtual_checks: &mut virtual_checks,
calls: &mut calls, calls: &mut calls,
@ -307,11 +357,11 @@ fn test_simple_call() {
&mut *top_level.definitions.read()[foo_id].write() &mut *top_level.definitions.read()[foo_id].write()
{ {
instance_to_stmt.insert( instance_to_stmt.insert(
"".to_string(), String::new(),
FunInstance { FunInstance {
body: Arc::new(statements_2), body: Arc::new(statements_2),
calls: Arc::new(inferencer.calls.clone()), calls: Arc::new(inferencer.calls.clone()),
subst: Default::default(), subst: IndexMap::default(),
unifier_id: 0, unifier_id: 0,
}, },
); );
@ -327,7 +377,7 @@ fn test_simple_call() {
}); });
let task = CodeGenTask { let task = CodeGenTask {
subst: Default::default(), subst: Vec::default(),
symbol_name: "testing".to_string(), symbol_name: "testing".to_string(),
body: Arc::new(statements_1), body: Arc::new(statements_1),
calls: Arc::new(calls1), calls: Arc::new(calls1),
@ -341,24 +391,86 @@ fn test_simple_call() {
let expected = indoc! {" let expected = indoc! {"
; ModuleID = 'test' ; ModuleID = 'test'
source_filename = \"test\" source_filename = \"test\"
target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128\"
target triple = \"x86_64-unknown-linux-gnu\"
define i32 @testing(i32 %0) { ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @testing(i32 %0) local_unnamed_addr #0 !dbg !5 {
init: init:
%call = call i32 @foo.0(i32 %0) %add.i = shl i32 %0, 1, !dbg !10
%mul = mul i32 %call, 2 %mul = add i32 %add.i, 2, !dbg !10
ret i32 %mul ret i32 %mul, !dbg !10
} }
define i32 @foo.0(i32 %0) { ; Function Attrs: mustprogress nofree norecurse nosync nounwind readnone willreturn
define i32 @foo.0(i32 %0) local_unnamed_addr #0 !dbg !11 {
init: init:
%add = add i32 %0, 1 %add = add i32 %0, 1, !dbg !12
ret i32 %add ret i32 %add, !dbg !12
} }
"}
attributes #0 = { mustprogress nofree norecurse nosync nounwind readnone willreturn }
!llvm.module.flags = !{!0, !1}
!llvm.dbg.cu = !{!2, !4}
!0 = !{i32 2, !\"Debug Info Version\", i32 3}
!1 = !{i32 2, !\"Dwarf Version\", i32 4}
!2 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
!3 = !DIFile(filename: \"unknown\", directory: \"\")
!4 = distinct !DICompileUnit(language: DW_LANG_Python, file: !3, producer: \"NAC3\", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug)
!5 = distinct !DISubprogram(name: \"testing\", linkageName: \"testing\", scope: null, file: !3, line: 1, type: !6, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !2, retainedNodes: !9)
!6 = !DISubroutineType(flags: DIFlagPublic, types: !7)
!7 = !{!8}
!8 = !DIBasicType(name: \"_\", flags: DIFlagPublic)
!9 = !{}
!10 = !DILocation(line: 2, column: 12, scope: !5)
!11 = distinct !DISubprogram(name: \"foo.0\", linkageName: \"foo.0\", scope: null, file: !3, line: 1, type: !6, scopeLine: 1, flags: DIFlagPublic, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !4, retainedNodes: !9)
!12 = !DILocation(line: 1, column: 12, scope: !11)
"}
.trim(); .trim();
assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); assert_eq!(expected, module.print_to_string().to_str().unwrap().trim());
}))); })));
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, f);
Target::initialize_all(&InitializationConfig::default());
let llvm_options = CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(),
};
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
} }
#[test]
fn test_classes_list_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);
let llvm_list = ListType::new(&generator, &ctx, llvm_i32.into());
assert!(ListType::is_type(llvm_list.as_base_type(), llvm_usize).is_ok());
}
#[test]
fn test_classes_range_type_new() {
let ctx = inkwell::context::Context::create();
let llvm_range = RangeType::new(&ctx);
assert!(RangeType::is_type(llvm_range.as_base_type()).is_ok());
}
#[test]
fn test_classes_ndarray_type_new() {
let ctx = inkwell::context::Context::create();
let generator = DefaultCodeGenerator::new(String::new(), 64);
let llvm_i32 = ctx.i32_type();
let llvm_usize = generator.get_size_type(&ctx);
let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into());
assert!(NDArrayType::is_type(llvm_ndarray.as_base_type(), llvm_usize).is_ok());
}

View File

@ -1,5 +1,27 @@
#![warn(clippy::all)] #![deny(
#![allow(dead_code)] future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
dead_code,
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::enum_glob_use,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::similar_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
// users of nac3core need to use the same version of these dependencies, so expose them as nac3core::*
pub use inkwell;
pub use nac3parser;
pub mod codegen; pub mod codegen;
pub mod symbol_resolver; pub mod symbol_resolver;

View File

@ -1,23 +1,24 @@
use std::fmt::Debug; use std::{
use std::sync::Arc; collections::{HashMap, HashSet},
use std::{collections::HashMap, fmt::Display}; fmt::{Debug, Display},
rc::Rc,
use crate::typecheck::typedef::TypeEnum; sync::Arc,
use crate::{
codegen::CodeGenContext,
toplevel::{DefinitionId, TopLevelDef},
}; };
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip, Itertools};
use parking_lot::RwLock;
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use crate::{ use crate::{
codegen::CodeGenerator, codegen::{CodeGenContext, CodeGenerator},
toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{Type, Unifier}, typedef::{Type, TypeEnum, Unifier, VarMap},
}, },
}; };
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, izip};
use nac3parser::ast::{Expr, Location, StrRef};
use parking_lot::RwLock;
#[derive(Clone, PartialEq, Debug)] #[derive(Clone, PartialEq, Debug)]
pub enum SymbolValue { pub enum SymbolValue {
@ -33,15 +34,192 @@ pub enum SymbolValue {
OptionNone, OptionNone,
} }
impl SymbolValue {
/// Creates a [`SymbolValue`] from a [`Constant`].
///
/// * `constant` - The constant to create the value from.
/// * `expected_ty` - The expected type of the [`SymbolValue`].
pub fn from_constant(
constant: &Constant,
expected_ty: Type,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> Result<Self, String> {
match constant {
Constant::None => {
if unifier.unioned(expected_ty, primitives.option) {
Ok(SymbolValue::OptionNone)
} else {
Err(format!("Expected {expected_ty:?}, but got Option"))
}
}
Constant::Bool(b) => {
if unifier.unioned(expected_ty, primitives.bool) {
Ok(SymbolValue::Bool(*b))
} else {
Err(format!("Expected {expected_ty:?}, but got bool"))
}
}
Constant::Str(s) => {
if unifier.unioned(expected_ty, primitives.str) {
Ok(SymbolValue::Str(s.to_string()))
} else {
Err(format!("Expected {expected_ty:?}, but got str"))
}
}
Constant::Int(i) => {
if unifier.unioned(expected_ty, primitives.int32) {
i32::try_from(*i).map(SymbolValue::I32).map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.int64) {
i64::try_from(*i).map(SymbolValue::I64).map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint32) {
u32::try_from(*i).map(SymbolValue::U32).map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint64) {
u64::try_from(*i).map(SymbolValue::U64).map_err(|e| e.to_string())
} else {
Err(format!("Expected {}, but got int", unifier.stringify(expected_ty)))
}
}
Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty, is_vararg_ctx } = expected_ty.as_ref() else {
return Err(format!(
"Expected {:?}, but got Tuple",
expected_ty.get_type_name()
));
};
assert!(*is_vararg_ctx || ty.len() == t.len());
let elems = t
.iter()
.zip(ty)
.map(|(constant, ty)| Self::from_constant(constant, *ty, primitives, unifier))
.collect::<Result<Vec<SymbolValue>, _>>()?;
Ok(SymbolValue::Tuple(elems))
}
Constant::Float(f) => {
if unifier.unioned(expected_ty, primitives.float) {
Ok(SymbolValue::Double(*f))
} else {
Err(format!("Expected {expected_ty:?}, but got float"))
}
}
_ => Err(format!("Unsupported value type {constant:?}")),
}
}
/// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value.
///
/// * `constant` - The constant to create the value from.
pub fn from_constant_inferred(constant: &Constant) -> Result<Self, String> {
match constant {
Constant::None => Ok(SymbolValue::OptionNone),
Constant::Bool(b) => Ok(SymbolValue::Bool(*b)),
Constant::Str(s) => Ok(SymbolValue::Str(s.to_string())),
Constant::Int(i) => {
let i = *i;
if i >= 0 {
i32::try_from(i)
.map(SymbolValue::I32)
.or_else(|_| i64::try_from(i).map(SymbolValue::I64))
.map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
} else {
u32::try_from(i)
.map(SymbolValue::U32)
.or_else(|_| u64::try_from(i).map(SymbolValue::U64))
.map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
}
}
Constant::Tuple(t) => {
let elems = t
.iter()
.map(Self::from_constant_inferred)
.collect::<Result<Vec<SymbolValue>, _>>()?;
Ok(SymbolValue::Tuple(elems))
}
Constant::Float(f) => Ok(SymbolValue::Double(*f)),
_ => Err(format!("Unsupported value type {constant:?}")),
}
}
/// Returns the [`Type`] representing the data type of this value.
pub fn get_type(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Type {
match self {
SymbolValue::I32(_) => primitives.int32,
SymbolValue::I64(_) => primitives.int64,
SymbolValue::U32(_) => primitives.uint32,
SymbolValue::U64(_) => primitives.uint64,
SymbolValue::Str(_) => primitives.str,
SymbolValue::Double(_) => primitives.float,
SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => {
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys, is_vararg_ctx: false })
}
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
}
}
/// Returns the [`TypeAnnotation`] representing the data type of this value.
pub fn get_type_annotation(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> TypeAnnotation {
match self {
SymbolValue::Bool(..)
| SymbolValue::Double(..)
| SymbolValue::I32(..)
| SymbolValue::I64(..)
| SymbolValue::U32(..)
| SymbolValue::U64(..)
| SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)),
SymbolValue::Tuple(vs) => {
let vs_tys = vs
.iter()
.map(|v| v.get_type_annotation(primitives, unifier))
.collect::<Vec<_>>();
TypeAnnotation::Tuple(vs_tys)
}
SymbolValue::OptionNone => TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(),
params: Vec::default(),
},
SymbolValue::OptionSome(v) => {
let ty = v.get_type_annotation(primitives, unifier);
TypeAnnotation::CustomClass {
id: primitives.option.obj_id(unifier).unwrap(),
params: vec![ty],
}
}
}
}
/// Returns the [`TypeEnum`] representing the data type of this value.
pub fn get_type_enum(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> Rc<TypeEnum> {
let ty = self.get_type(primitives, unifier);
unifier.get_ty(ty)
}
}
impl Display for SymbolValue { impl Display for SymbolValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
SymbolValue::I32(i) => write!(f, "{}", i), SymbolValue::I32(i) => write!(f, "{i}"),
SymbolValue::I64(i) => write!(f, "int64({})", i), SymbolValue::I64(i) => write!(f, "int64({i})"),
SymbolValue::U32(i) => write!(f, "uint32({})", i), SymbolValue::U32(i) => write!(f, "uint32({i})"),
SymbolValue::U64(i) => write!(f, "uint64({})", i), SymbolValue::U64(i) => write!(f, "uint64({i})"),
SymbolValue::Str(s) => write!(f, "\"{}\"", s), SymbolValue::Str(s) => write!(f, "\"{s}\""),
SymbolValue::Double(d) => write!(f, "{}", d), SymbolValue::Double(d) => write!(f, "{d}"),
SymbolValue::Bool(b) => { SymbolValue::Bool(b) => {
if *b { if *b {
write!(f, "True") write!(f, "True")
@ -50,39 +228,82 @@ impl Display for SymbolValue {
} }
} }
SymbolValue::Tuple(t) => { SymbolValue::Tuple(t) => {
write!(f, "({})", t.iter().map(|v| format!("{}", v)).collect::<Vec<_>>().join(", ")) write!(f, "({})", t.iter().map(|v| format!("{v}")).collect::<Vec<_>>().join(", "))
} }
SymbolValue::OptionSome(v) => write!(f, "Some({})", v), SymbolValue::OptionSome(v) => write!(f, "Some({v})"),
SymbolValue::OptionNone => write!(f, "none"), SymbolValue::OptionNone => write!(f, "none"),
} }
} }
} }
impl TryFrom<SymbolValue> for u64 {
type Error = ();
/// Tries to convert a [`SymbolValue`] into a [`u64`], returning [`Err`] if the value is not
/// numeric or if the value cannot be converted into a `u64` without overflow.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::U32(v) => Ok(u64::from(v)),
SymbolValue::U64(v) => Ok(v),
_ => Err(()),
}
}
}
impl TryFrom<SymbolValue> for i128 {
type Error = ();
/// Tries to convert a [`SymbolValue`] into a [`i128`], returning [`Err`] if the value is not
/// numeric.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => Ok(i128::from(v)),
SymbolValue::I64(v) => Ok(i128::from(v)),
SymbolValue::U32(v) => Ok(i128::from(v)),
SymbolValue::U64(v) => Ok(i128::from(v)),
_ => Err(()),
}
}
}
pub trait StaticValue { pub trait StaticValue {
/// Returns a unique identifier for this value.
fn get_unique_identifier(&self) -> u64; fn get_unique_identifier(&self) -> u64;
fn get_const_obj<'ctx, 'a>( /// Returns the constant object represented by this unique identifier.
fn get_const_obj<'ctx>(
&self, &self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
) -> BasicValueEnum<'ctx>; ) -> BasicValueEnum<'ctx>;
fn to_basic_value_enum<'ctx, 'a>( /// Converts this value to a LLVM [`BasicValueEnum`].
fn to_basic_value_enum<'ctx>(
&self, &self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
expected_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String>; ) -> Result<BasicValueEnum<'ctx>, String>;
fn get_field<'ctx, 'a>( /// Returns a field within this value.
fn get_field<'ctx>(
&self, &self,
name: StrRef, name: StrRef,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
) -> Option<ValueEnum<'ctx>>; ) -> Option<ValueEnum<'ctx>>;
/// Returns a single element of this tuple.
fn get_tuple_element<'ctx>(&self, index: u32) -> Option<ValueEnum<'ctx>>;
} }
#[derive(Clone)] #[derive(Clone)]
pub enum ValueEnum<'ctx> { pub enum ValueEnum<'ctx> {
/// [`ValueEnum`] representing a static value.
Static(Arc<dyn StaticValue + Send + Sync>), Static(Arc<dyn StaticValue + Send + Sync>),
/// [`ValueEnum`] representing a dynamic value.
Dynamic(BasicValueEnum<'ctx>), Dynamic(BasicValueEnum<'ctx>),
} }
@ -117,20 +338,22 @@ impl<'ctx> From<StructValue<'ctx>> for ValueEnum<'ctx> {
} }
impl<'ctx> ValueEnum<'ctx> { impl<'ctx> ValueEnum<'ctx> {
/// Converts this [`ValueEnum`] to a [`BasicValueEnum`].
pub fn to_basic_value_enum<'a>( pub fn to_basic_value_enum<'a>(
self, self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
expected_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
match self { match self {
ValueEnum::Static(v) => v.to_basic_value_enum(ctx, generator), ValueEnum::Static(v) => v.to_basic_value_enum(ctx, generator, expected_ty),
ValueEnum::Dynamic(v) => Ok(v), ValueEnum::Dynamic(v) => Ok(v),
} }
} }
} }
pub trait SymbolResolver { pub trait SymbolResolver {
// get type of type variable identifier or top-level function type /// Get type of type variable identifier or top-level function type,
fn get_symbol_type( fn get_symbol_type(
&self, &self,
unifier: &mut Unifier, unifier: &mut Unifier,
@ -139,16 +362,17 @@ pub trait SymbolResolver {
str: StrRef, str: StrRef,
) -> Result<Type, String>; ) -> Result<Type, String>;
// get the top-level definition of identifiers /// Get the top-level definition of identifiers.
fn get_identifier_def(&self, str: StrRef) -> Result<DefinitionId, String>; fn get_identifier_def(&self, str: StrRef) -> Result<DefinitionId, HashSet<String>>;
fn get_symbol_value<'ctx, 'a>( fn get_symbol_value<'ctx>(
&self, &self,
str: StrRef, str: StrRef,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>>; ) -> Option<ValueEnum<'ctx>>;
fn get_default_param_value(&self, expr: &nac3parser::ast::Expr) -> Option<SymbolValue>; fn get_default_param_value(&self, expr: &Expr) -> Option<SymbolValue>;
fn get_string_id(&self, s: &str) -> i32; fn get_string_id(&self, s: &str) -> i32;
fn get_exception_id(&self, tyid: usize) -> usize; fn get_exception_id(&self, tyid: usize) -> usize;
@ -156,7 +380,7 @@ pub trait SymbolResolver {
&self, &self,
_unifier: &mut Unifier, _unifier: &mut Unifier,
_top_level_defs: &[Arc<RwLock<TopLevelDef>>], _top_level_defs: &[Arc<RwLock<TopLevelDef>>],
_primitives: &PrimitiveStore _primitives: &PrimitiveStore,
) -> Result<(), String> { ) -> Result<(), String> {
Ok(()) Ok(())
} }
@ -169,23 +393,23 @@ thread_local! {
"float".into(), "float".into(),
"bool".into(), "bool".into(),
"virtual".into(), "virtual".into(),
"list".into(),
"tuple".into(), "tuple".into(),
"str".into(), "str".into(),
"Exception".into(), "Exception".into(),
"uint32".into(), "uint32".into(),
"uint64".into(), "uint64".into(),
"Literal".into(),
]; ];
} }
// convert type annotation into type /// Converts a type annotation into a [Type].
pub fn parse_type_annotation<T>( pub fn parse_type_annotation<T>(
resolver: &dyn SymbolResolver, resolver: &dyn SymbolResolver,
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
expr: &Expr<T>, expr: &Expr<T>,
) -> Result<Type, String> { ) -> Result<Type, HashSet<String>> {
use nac3parser::ast::ExprKind::*; use nac3parser::ast::ExprKind::*;
let ids = IDENTIFIER_ID.with(|ids| *ids); let ids = IDENTIFIER_ID.with(|ids| *ids);
let int32_id = ids[0]; let int32_id = ids[0];
@ -193,12 +417,12 @@ pub fn parse_type_annotation<T>(
let float_id = ids[2]; let float_id = ids[2];
let bool_id = ids[3]; let bool_id = ids[3];
let virtual_id = ids[4]; let virtual_id = ids[4];
let list_id = ids[5]; let tuple_id = ids[5];
let tuple_id = ids[6]; let str_id = ids[6];
let str_id = ids[7]; let exn_id = ids[7];
let exn_id = ids[8]; let uint32_id = ids[8];
let uint32_id = ids[9]; let uint64_id = ids[9];
let uint64_id = ids[10]; let literal_id = ids[10];
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| { let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
if *id == int32_id { if *id == int32_id {
@ -219,39 +443,33 @@ pub fn parse_type_annotation<T>(
Ok(primitives.exception) Ok(primitives.exception)
} else { } else {
let obj_id = resolver.get_identifier_def(*id); let obj_id = resolver.get_identifier_def(*id);
match obj_id { if let Ok(obj_id) = obj_id {
Ok(obj_id) => { let def = top_level_defs[obj_id.0].read();
let def = top_level_defs[obj_id.0].read(); if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if !type_vars.is_empty() {
if !type_vars.is_empty() { return Err(HashSet::from([format!(
return Err(format!( "Unexpected number of type parameters: expected {} but got 0",
"Unexpected number of type parameters: expected {} but got 0", type_vars.len()
type_vars.len() )]));
));
}
let fields = chain(
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
)
.collect();
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields,
params: Default::default(),
}))
} else {
Err(format!("Cannot use function name as type at {}", loc))
} }
let fields = chain(
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
)
.collect();
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: VarMap::default() }))
} else {
Err(HashSet::from([format!("Cannot use function name as type at {loc}")]))
} }
Err(_) => { } else {
let ty = resolver let ty =
.get_symbol_type(unifier, top_level_defs, primitives, *id) resolver.get_symbol_type(unifier, top_level_defs, primitives, *id).map_err(
.map_err(|e| format!("Unknown type annotation at {}: {}", loc, e))?; |e| HashSet::from([format!("Unknown type annotation at {loc}: {e}")]),
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { )?;
Ok(ty) if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
} else { Ok(ty)
Err(format!("Unknown type annotation {} at {}", id, loc)) } else {
} Err(HashSet::from([format!("Unknown type annotation {id} at {loc}")]))
} }
} }
} }
@ -261,9 +479,6 @@ pub fn parse_type_annotation<T>(
if *id == virtual_id { if *id == virtual_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?; let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} else if *id == list_id {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
} else if *id == tuple_id { } else if *id == tuple_id {
if let Tuple { elts, .. } = &slice.node { if let Tuple { elts, .. } = &slice.node {
let ty = elts let ty = elts
@ -272,10 +487,33 @@ pub fn parse_type_annotation<T>(
parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt) parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty })) Ok(unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }))
} else { } else {
Err("Expected multiple elements for tuple".into()) Err(HashSet::from(["Expected multiple elements for tuple".into()]))
} }
} else if *id == literal_id {
let mut parse_literal = |elt: &Expr<T>| {
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, elt)?;
let ty_enum = &*unifier.get_ty_immutable(ty);
match ty_enum {
TypeEnum::TLiteral { values, .. } => Ok(values.clone()),
_ => Err(HashSet::from([format!(
"Expected literal in type argument for Literal at {}",
elt.location
)])),
}
};
let values = if let Tuple { elts, .. } = &slice.node {
elts.iter().map(&mut parse_literal).collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_literal(slice)?]
}
.into_iter()
.flatten()
.collect_vec();
Ok(unifier.get_fresh_literal(values, Some(slice.location)))
} else { } else {
let types = if let Tuple { elts, .. } = &slice.node { let types = if let Tuple { elts, .. } = &slice.node {
elts.iter() elts.iter()
@ -291,13 +529,13 @@ pub fn parse_type_annotation<T>(
let def = top_level_defs[obj_id.0].read(); let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() { if types.len() != type_vars.len() {
return Err(format!( return Err(HashSet::from([format!(
"Unexpected number of type parameters: expected {} but got {}", "Unexpected number of type parameters: expected {} but got {}",
type_vars.len(), type_vars.len(),
types.len() types.len()
)); )]));
} }
let mut subst = HashMap::new(); let mut subst = VarMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) { for (var, ty) in izip!(type_vars.iter(), types.iter()) {
let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) { let id = if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*var) {
*id *id
@ -319,7 +557,7 @@ pub fn parse_type_annotation<T>(
})); }));
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst })) Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst }))
} else { } else {
Err("Cannot use function name as type".into()) Err(HashSet::from(["Cannot use function name as type".into()]))
} }
} }
}; };
@ -330,10 +568,13 @@ pub fn parse_type_annotation<T>(
if let Name { id, .. } = &value.node { if let Name { id, .. } = &value.node {
subscript_name_handle(id, slice, unifier) subscript_name_handle(id, slice, unifier)
} else { } else {
Err(format!("unsupported type expression at {}", expr.location)) Err(HashSet::from([format!("unsupported type expression at {}", expr.location)]))
} }
} }
_ => Err(format!("unsupported type expression at {}", expr.location)), Constant { value, .. } => SymbolValue::from_constant_inferred(value)
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location)))
.map_err(|err| HashSet::from([err])),
_ => Err(HashSet::from([format!("unsupported type expression at {}", expr.location)])),
} }
} }
@ -344,7 +585,7 @@ impl dyn SymbolResolver + Send + Sync {
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
expr: &Expr<T>, expr: &Expr<T>,
) -> Result<Type, String> { ) -> Result<Type, HashSet<String>> {
parse_type_annotation(self, top_level_defs, unifier, primitives, expr) parse_type_annotation(self, top_level_defs, unifier, primitives, expr)
} }
@ -357,13 +598,13 @@ impl dyn SymbolResolver + Send + Sync {
unifier.internal_stringify( unifier.internal_stringify(
ty, ty,
&mut |id| { &mut |id| {
if let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() { let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else {
name.to_string()
} else {
unreachable!("expected class definition") unreachable!("expected class definition")
} };
name.to_string()
}, },
&mut |id| format!("var{}", id), &mut |id| format!("typevar{id}"),
&mut None, &mut None,
) )
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -3,59 +3,70 @@ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
fmt::Debug, fmt::Debug,
iter::FromIterator, iter::FromIterator,
ops::{Deref, DerefMut},
sync::Arc, sync::Arc,
}; };
use super::codegen::CodeGenContext;
use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier};
use crate::{
codegen::CodeGenerator,
symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{type_inferencer::CodeLocation, typedef::CallId},
};
use inkwell::values::BasicValueEnum; use inkwell::values::BasicValueEnum;
use itertools::{izip, Itertools}; use itertools::Itertools;
use nac3parser::ast::{self, Location, Stmt, StrRef};
use parking_lot::RwLock; use parking_lot::RwLock;
use nac3parser::ast::{self, Expr, Location, Stmt, StrRef};
use crate::{
codegen::{CodeGenContext, CodeGenerator},
symbol_resolver::{SymbolResolver, ValueEnum},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{
CallId, FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, TypeVarId, Unifier,
VarMap,
},
},
};
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)] #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Debug)]
pub struct DefinitionId(pub usize); pub struct DefinitionId(pub usize);
pub mod builtins; pub mod builtins;
pub mod composer; pub mod composer;
pub mod helper; pub mod helper;
pub mod numpy;
pub mod type_annotation; pub mod type_annotation;
use composer::*; use composer::*;
use type_annotation::*; use type_annotation::*;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
type GenCallCallback = Box< type GenCallCallback = dyn for<'ctx, 'a> Fn(
dyn for<'ctx, 'a> Fn( &mut CodeGenContext<'ctx, 'a>,
&mut CodeGenContext<'ctx, 'a>, Option<(Type, ValueEnum<'ctx>)>,
Option<(Type, ValueEnum<'ctx>)>, (&FunSignature, DefinitionId),
(&FunSignature, DefinitionId), Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
Vec<(Option<StrRef>, ValueEnum<'ctx>)>, &mut dyn CodeGenerator,
&mut dyn CodeGenerator, ) -> Result<Option<BasicValueEnum<'ctx>>, String>
) -> Result<Option<BasicValueEnum<'ctx>>, String> + Send
+ Send + Sync;
+ Sync,
>;
pub struct GenCall { pub struct GenCall {
fp: GenCallCallback, fp: Box<GenCallCallback>,
} }
impl GenCall { impl GenCall {
pub fn new(fp: GenCallCallback) -> GenCall { #[must_use]
pub fn new(fp: Box<GenCallCallback>) -> GenCall {
GenCall { fp } GenCall { fp }
} }
pub fn run<'ctx, 'a>( /// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given
/// `reason`.
#[must_use]
pub fn create_dummy(reason: String) -> GenCall {
Self::new(Box::new(move |_, _, _, _, _| unreachable!("{reason}")))
}
pub fn run<'ctx>(
&self, &self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
@ -75,58 +86,85 @@ impl Debug for GenCall {
pub struct FunInstance { pub struct FunInstance {
pub body: Arc<Vec<Stmt<Option<Type>>>>, pub body: Arc<Vec<Stmt<Option<Type>>>>,
pub calls: Arc<HashMap<CodeLocation, CallId>>, pub calls: Arc<HashMap<CodeLocation, CallId>>,
pub subst: HashMap<u32, Type>, pub subst: VarMap,
pub unifier_id: usize, pub unifier_id: usize,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum TopLevelDef { pub enum TopLevelDef {
Class { Class {
// name for error messages and symbols /// Name for error messages and symbols.
name: StrRef, name: StrRef,
// object ID used for TypeEnum /// Object ID used for [`TypeEnum`].
object_id: DefinitionId, object_id: DefinitionId,
/// type variables bounded to the class. /// type variables bounded to the class.
type_vars: Vec<Type>, type_vars: Vec<Type>,
// class fields /// Class fields.
// name, type, is mutable ///
/// Name and type is mutable.
fields: Vec<(StrRef, Type, bool)>, fields: Vec<(StrRef, Type, bool)>,
// class methods, pointing to the corresponding function definition. /// Class Attributes.
///
/// Name, type, value.
attributes: Vec<(StrRef, Type, ast::Constant)>,
/// Class methods, pointing to the corresponding function definition.
methods: Vec<(StrRef, Type, DefinitionId)>, methods: Vec<(StrRef, Type, DefinitionId)>,
// ancestor classes, including itself. /// Ancestor classes, including itself.
ancestors: Vec<TypeAnnotation>, ancestors: Vec<TypeAnnotation>,
// symbol resolver of the module defined the class, none if it is built-in type /// Symbol resolver of the module defined the class; [None] if it is built-in type.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>, resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
// constructor type /// Constructor type.
constructor: Option<Type>, constructor: Option<Type>,
// definition location /// Definition location.
loc: Option<Location>, loc: Option<Location>,
}, },
Function { Function {
// prefix for symbol, should be unique globally /// Prefix for symbol, should be unique globally.
name: String, name: String,
// simple name, the same as in method/function definition /// Simple name, the same as in method/function definition.
simple_name: StrRef, simple_name: StrRef,
// function signature. /// Function signature.
signature: Type, signature: Type,
// instantiated type variable IDs /// Instantiated type variable IDs.
var_id: Vec<u32>, var_id: Vec<TypeVarId>,
/// Function instance to symbol mapping /// Function instance to symbol mapping
/// Key: string representation of type variable values, sorted by variable ID in ascending ///
/// order, including type variables associated with the class. /// * Key: String representation of type variable values, sorted by variable ID in ascending
/// Value: function symbol name. /// order, including type variables associated with the class.
/// * Value: Function symbol name.
instance_to_symbol: HashMap<String, String>, instance_to_symbol: HashMap<String, String>,
/// Function instances to annotated AST mapping /// Function instances to annotated AST mapping
/// Key: string representation of type variable values, sorted by variable ID in ascending ///
/// order, including type variables associated with the class. Excluding rigid type /// * Key: String representation of type variable values, sorted by variable ID in ascending
/// variables. /// order, including type variables associated with the class. Excluding rigid type
/// rigid type variables that would be substituted when the function is instantiated. /// variables.
///
/// Rigid type variables that would be substituted when the function is instantiated.
instance_to_stmt: HashMap<String, FunInstance>, instance_to_stmt: HashMap<String, FunInstance>,
// symbol resolver of the module defined the class /// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>, resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
// custom codegen callback /// Custom code generation callback.
codegen_callback: Option<Arc<GenCall>>, codegen_callback: Option<Arc<GenCall>>,
// definition location /// Definition location.
loc: Option<Location>,
},
Variable {
/// Qualified name of the global variable, should be unique globally.
name: String,
/// Simple name, the same as in method/function definition.
simple_name: StrRef,
/// Type of the global variable.
ty: Type,
/// The declared type of the global variable, or [`None`] if no type annotation is provided.
ty_decl: Option<Expr>,
/// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Definition location.
loc: Option<Location>, loc: Option<Location>,
}, },
} }

View File

@ -0,0 +1,86 @@
use itertools::Itertools;
use crate::{
toplevel::helper::PrimDef,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, TypeVarId, Unifier, VarMap},
},
};
/// Creates a `ndarray` [`Type`] with the given type arguments.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn make_ndarray_ty(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
}
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
if dtype.is_none() && ndims.is_none() {
return ndarray;
}
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
debug_assert_eq!(tvar_ids.len(), 2);
let mut tvar_subst = VarMap::new();
if let Some(dtype) = dtype {
tvar_subst.insert(tvar_ids[0], dtype);
}
if let Some(ndims) = ndims {
tvar_subst.insert(tvar_ids[1], ndims);
}
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(TypeVarId, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
debug_assert_eq!(params.len(), 2);
params
.iter()
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarId, TypeVarId) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
}

View File

@ -1,13 +1,11 @@
--- ---
source: nac3core/src/toplevel/test.rs source: nac3core/src/toplevel/test.rs
assertion_line: 549
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [18]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(241)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -1,15 +1,13 @@
--- ---
source: nac3core/src/toplevel/test.rs source: nac3core/src/toplevel/test.rs
assertion_line: 549
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[var7]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"var7\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -1,15 +1,13 @@
--- ---
source: nac3core/src/toplevel/test.rs source: nac3core/src/toplevel/test.rs
assertion_line: 549
expression: res_vec expression: res_vec
--- ---
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [20]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [25]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
] ]

View File

@ -1,15 +1,13 @@
--- ---
source: nac3core/src/toplevel/test.rs source: nac3core/src/toplevel/test.rs
assertion_line: 549
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[var6, var7]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"var6\", \"var7\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n",
] ]

View File

@ -1,19 +1,17 @@
--- ---
source: nac3core/src/toplevel/test.rs source: nac3core/src/toplevel/test.rs
assertion_line: 549
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [26]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [34]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n",
] ]

View File

@ -1,19 +1,24 @@
use std::{collections::HashMap, sync::Arc};
use indoc::indoc;
use parking_lot::Mutex;
use test_case::test_case;
use nac3parser::{
ast::{fold::Fold, FileName},
parser::parse_program,
};
use super::*;
use crate::{ use crate::{
codegen::CodeGenContext, codegen::CodeGenContext,
symbol_resolver::{SymbolResolver, ValueEnum}, symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::DefinitionId, toplevel::{helper::PrimDef, DefinitionId},
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{Type, Unifier}, typedef::{into_var_map, Type, Unifier},
}, },
}; };
use indoc::indoc;
use nac3parser::{ast::fold::Fold, parser::parse_program};
use parking_lot::Mutex;
use std::{collections::HashMap, sync::Arc};
use test_case::test_case;
use super::*;
struct ResolverInternal { struct ResolverInternal {
id_to_type: Mutex<HashMap<StrRef, Type>>, id_to_type: Mutex<HashMap<StrRef, Type>>,
@ -36,7 +41,7 @@ struct Resolver(Arc<ResolverInternal>);
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_default_param_value( fn get_default_param_value(
&self, &self,
_: &nac3parser::ast::Expr, _: &ast::Expr,
) -> Option<crate::symbol_resolver::SymbolValue> { ) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!() unimplemented!()
} }
@ -52,20 +57,26 @@ impl SymbolResolver for Resolver {
.id_to_type .id_to_type
.lock() .lock()
.get(&str) .get(&str)
.cloned() .copied()
.ok_or_else(|| format!("cannot find symbol `{}`", str)) .ok_or_else(|| format!("cannot find symbol `{str}`"))
} }
fn get_symbol_value<'ctx, 'a>( fn get_symbol_value<'ctx>(
&self, &self,
_: StrRef, _: StrRef,
_: &mut CodeGenContext<'ctx, 'a>, _: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
unimplemented!() unimplemented!()
} }
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, String> { fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.0.id_to_def.lock().get(&id).cloned().ok_or_else(|| "Unknown identifier".to_string()) self.0
.id_to_def
.lock()
.get(&id)
.copied()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
} }
fn get_string_id(&self, _: &str) -> i32 { fn get_string_id(&self, _: &str) -> i32 {
@ -105,23 +116,41 @@ impl SymbolResolver for Resolver {
def __init__(self): def __init__(self):
self.c: int32 = 4 self.c: int32 = 4
self.a: bool = True self.a: bool = True
"} "},
]; ];
"register" "register"
)] )]
fn test_simple_register(source: Vec<&str>) { fn test_simple_register(source: Vec<&str>) {
let mut composer: TopLevelComposer = Default::default(); let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
for s in source { for s in source {
let ast = parse_program(s, Default::default()).unwrap(); let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
composer.register_top_level(ast, None, "".into()).unwrap(); composer.register_top_level(ast, None, "", false).unwrap();
} }
} }
#[test_case( #[test_case(
vec![ indoc! {"
class A:
def foo(self):
pass
a = A()
"};
"register"
)]
fn test_simple_register_without_constructor(source: &str) {
let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let ast = parse_program(source, FileName::default()).unwrap();
let ast = ast[0].clone();
composer.register_top_level(ast, None, "", true).unwrap();
}
#[test_case(
&[
indoc! {" indoc! {"
def fun(a: int32) -> int32: def fun(a: int32) -> int32:
return a return a
@ -135,35 +164,36 @@ fn test_simple_register(source: Vec<&str>) {
return 3 return 3
"}, "},
], ],
vec![ &[
"fn[[a:0], 0]", "fn[[a:0], 0]",
"fn[[a:2], 4]", "fn[[a:2], 4]",
"fn[[b:1], 0]", "fn[[b:1], 0]",
], ],
vec![ &[
"fun", "fun",
"foo", "foo",
"f" "f"
]; ];
"function compose" "function compose"
)] )]
fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&str>) { fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
let mut composer: TopLevelComposer = Default::default(); let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = Arc::new(ResolverInternal { let internal_resolver = Arc::new(ResolverInternal {
id_to_def: Default::default(), id_to_def: Mutex::default(),
id_to_type: Default::default(), id_to_type: Mutex::default(),
class_names: Default::default(), class_names: Mutex::default(),
}); });
let resolver = let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>; Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source { for s in source {
let ast = parse_program(s, Default::default()).unwrap(); let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
let (id, def_id, ty) = let (id, def_id, ty) =
composer.register_top_level(ast, Some(resolver.clone()), "".into()).unwrap(); composer.register_top_level(ast, Some(resolver.clone()), "", false).unwrap();
internal_resolver.add_id_def(id, def_id); internal_resolver.add_id_def(id, def_id);
if let Some(ty) = ty { if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty); internal_resolver.add_id_type(id, ty);
@ -189,7 +219,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
} }
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(): class A():
a: int32 a: int32
@ -222,11 +252,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec![]; &[];
"simple class compose" "simple class compose"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class Generic_A(Generic[V], B): class Generic_A(Generic[V], B):
a: int64 a: int64
@ -244,11 +274,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec![]; &[];
"generic class" "generic class"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
def foo(a: list[int32], b: tuple[T, float]) -> A[B, bool]: def foo(a: list[int32], b: tuple[T, float]) -> A[B, bool]:
pass pass
@ -273,11 +303,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec![]; &[];
"list tuple generic" "list tuple generic"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(Generic[T, V]): class A(Generic[T, V]):
a: A[float, bool] a: A[float, bool]
@ -298,11 +328,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec![]; &[];
"self1" "self1"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(Generic[T]): class A(Generic[T]):
a: int32 a: int32
@ -332,11 +362,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec![]; &[];
"inheritance_override" "inheritance_override"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(Generic[T]): class A(Generic[T]):
def __init__(self): def __init__(self):
@ -345,11 +375,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec!["application of type vars to generic class is not currently supported (at unknown: line 4 column 24)"]; &["application of type vars to generic class is not currently supported (at unknown:4:24)"];
"err no type var in generic app" "err no type var in generic app"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(B): class A(B):
def __init__(self): def __init__(self):
@ -361,11 +391,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec!["cyclic inheritance detected"]; &["cyclic inheritance detected"];
"cyclic1" "cyclic1"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(B[bool, int64]): class A(B[bool, int64]):
def __init__(self): def __init__(self):
@ -382,30 +412,30 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"}, "},
], ],
vec!["cyclic inheritance detected"]; &["cyclic inheritance detected"];
"cyclic2" "cyclic2"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A: class A:
pass pass
"} "}
], ],
vec!["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"]; &["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"];
"simple pass in class" "simple pass in class"
)] )]
#[test_case( #[test_case(
vec![indoc! {" &[indoc! {"
class A: class A:
def __init__(): def __init__():
pass pass
"}], "}],
vec!["__init__ method must have a `self` parameter (at unknown: line 2 column 5)"]; &["__init__ method must have a `self` parameter (at unknown:2:5)"];
"err no self_1" "err no self_1"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(B, Generic[T], C): class A(B, Generic[T], C):
def __init__(self): def __init__(self):
@ -423,11 +453,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
"} "}
], ],
vec!["a class definition can only have at most one base class declaration and one generic declaration (at unknown: line 1 column 24)"]; &["a class definition can only have at most one base class declaration and one generic declaration (at unknown:1:24)"];
"err multiple inheritance" "err multiple inheritance"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(Generic[T]): class A(Generic[T]):
a: int32 a: int32
@ -448,11 +478,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec!["method fun has same name as ancestors' method, but incompatible type"]; &["method fun has same name as ancestors' method, but incompatible type"];
"err_incompatible_inheritance_method" "err_incompatible_inheritance_method"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A(Generic[T]): class A(Generic[T]):
a: int32 a: int32
@ -474,11 +504,11 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec!["field `a` has already declared in the ancestor classes"]; &["field `a` has already declared in the ancestor classes"];
"err_incompatible_inheritance_field" "err_incompatible_inheritance_field"
)] )]
#[test_case( #[test_case(
vec![ &[
indoc! {" indoc! {"
class A: class A:
def __init__(self): def __init__(self):
@ -491,12 +521,13 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
pass pass
"} "}
], ],
vec!["duplicate definition of class `A` (at unknown: line 1 column 1)"]; &["duplicate definition of class `A` (at unknown:1:1)"];
"class same name" "class same name"
)] )]
fn test_analyze(source: Vec<&str>, res: Vec<&str>) { fn test_analyze(source: &[&str], res: &[&str]) {
let print = false; let print = false;
let mut composer: TopLevelComposer = Default::default(); let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar( let internal_resolver = make_internal_resolver_with_tvar(
vec![ vec![
@ -511,15 +542,15 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>; Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source { for s in source {
let ast = parse_program(s, Default::default()).unwrap(); let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
let (id, def_id, ty) = { let (id, def_id, ty) = {
match composer.register_top_level(ast, Some(resolver.clone()), "".into()) { match composer.register_top_level(ast, Some(resolver.clone()), "", false) {
Ok(x) => x, Ok(x) => x,
Err(msg) => { Err(msg) => {
if print { if print {
println!("{}", msg); println!("{msg}");
} else { } else {
assert_eq!(res[0], msg); assert_eq!(res[0], msg);
} }
@ -535,9 +566,9 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
if let Err(msg) = composer.start_analysis(false) { if let Err(msg) = composer.start_analysis(false) {
if print { if print {
println!("{}", msg); println!("{}", msg.iter().sorted().join("\n----------\n"));
} else { } else {
assert_eq!(res[0], msg); assert_eq!(res[0], msg.iter().next().unwrap());
} }
} else { } else {
// skip 5 to skip primitives // skip 5 to skip primitives
@ -565,7 +596,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
return fib(n - 1) return fib(n - 1)
"} "}
], ],
vec![]; &[];
"simple function" "simple function"
)] )]
#[test_case( #[test_case(
@ -598,7 +629,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
return a.fun() + 2 return a.fun() + 2
"} "}
], ],
vec![]; &[];
"simple class body" "simple class body"
)] )]
#[test_case( #[test_case(
@ -623,7 +654,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
return [a, b] return [a, b]
"} "}
], ],
vec![]; &[];
"type var fun" "type var fun"
)] )]
#[test_case( #[test_case(
@ -644,7 +675,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
return ret if self.b else self.fun(self.a) return ret if self.b else self.fun(self.a)
"} "}
], ],
vec![]; &[];
"type var class" "type var class"
)] )]
#[test_case( #[test_case(
@ -668,12 +699,13 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
self.b = True self.b = True
"} "}
], ],
vec![]; &[];
"no_init_inst_check" "no_init_inst_check"
)] )]
fn test_inference(source: Vec<&str>, res: Vec<&str>) { fn test_inference(source: Vec<&str>, res: &[&str]) {
let print = true; let print = true;
let mut composer: TopLevelComposer = Default::default(); let mut composer =
TopLevelComposer::new(Vec::new(), Vec::new(), ComposerConfig::default(), 64).0;
let internal_resolver = make_internal_resolver_with_tvar( let internal_resolver = make_internal_resolver_with_tvar(
vec![ vec![
@ -695,15 +727,15 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>; Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
for s in source { for s in source {
let ast = parse_program(s, Default::default()).unwrap(); let ast = parse_program(s, FileName::default()).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
let (id, def_id, ty) = { let (id, def_id, ty) = {
match composer.register_top_level(ast, Some(resolver.clone()), "".into()) { match composer.register_top_level(ast, Some(resolver.clone()), "", false) {
Ok(x) => x, Ok(x) => x,
Err(msg) => { Err(msg) => {
if print { if print {
println!("{}", msg); println!("{msg}");
} else { } else {
assert_eq!(res[0], msg); assert_eq!(res[0], msg);
} }
@ -719,16 +751,14 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
if let Err(msg) = composer.start_analysis(true) { if let Err(msg) = composer.start_analysis(true) {
if print { if print {
println!("{}", msg); println!("{}", msg.iter().sorted().join("\n----------\n"));
} else { } else {
assert_eq!(res[0], msg); assert_eq!(res[0], msg.iter().next().unwrap());
} }
} else { } else {
// skip 5 to skip primitives // skip 5 to skip primitives
let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier }; let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier };
for (_i, (def, _)) in for (def, _) in composer.definition_ast_list.iter().skip(composer.builtin_num) {
composer.definition_ast_list.iter().skip(composer.builtin_num).enumerate()
{
let def = &*def.read(); let def = &*def.read();
if let TopLevelDef::Function { instance_to_stmt, name, .. } = def { if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {
@ -737,7 +767,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
name, name,
instance_to_stmt.len() instance_to_stmt.len()
); );
for inst in instance_to_stmt.iter() { for inst in instance_to_stmt {
let ast = &inst.1.body; let ast = &inst.1.body;
for b in ast.iter() { for b in ast.iter() {
println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap()); println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap());
@ -755,22 +785,29 @@ fn make_internal_resolver_with_tvar(
unifier: &mut Unifier, unifier: &mut Unifier,
print: bool, print: bool,
) -> Arc<ResolverInternal> { ) -> Arc<ResolverInternal> {
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let res: Arc<ResolverInternal> = ResolverInternal { let res: Arc<ResolverInternal> = ResolverInternal {
id_to_def: Default::default(), id_to_def: Mutex::new(HashMap::from([("list".into(), PrimDef::List.id())])),
id_to_type: tvars id_to_type: tvars
.into_iter() .into_iter()
.map(|(name, range)| { .map(|(name, range)| {
(name, { (name, {
let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice(), None, None); let tvar = unifier.get_fresh_var_with_range(range.as_slice(), None, None);
if print { if print {
println!("{}: {:?}, tvar{}", name, ty, id); println!("{}: {:?}, typevar{}", name, tvar.ty, tvar.id);
} }
ty tvar.ty
}) })
}) })
.collect::<HashMap<_, _>>() .collect::<HashMap<_, _>>()
.into(), .into(),
class_names: Default::default(), class_names: Mutex::new(HashMap::from([("list".into(), list)])),
} }
.into(); .into();
if print { if print {
@ -790,8 +827,8 @@ impl<'a> Fold<Option<Type>> for TypeToStringFolder<'a> {
Ok(if let Some(ty) = user { Ok(if let Some(ty) = user {
self.unifier.internal_stringify( self.unifier.internal_stringify(
ty, ty,
&mut |id| format!("class{}", id.to_string()), &mut |id| format!("class{id}"),
&mut |id| format!("tvar{}", id.to_string()), &mut |id| format!("typevar{id}"),
&mut None, &mut None,
) )
} else { } else {

View File

@ -1,4 +1,13 @@
use strum::IntoEnumIterator;
use nac3parser::ast::Constant;
use super::*; use super::*;
use crate::{
symbol_resolver::SymbolValue,
toplevel::helper::{PrimDef, PrimDefDetails},
typecheck::typedef::VarMap,
};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum TypeAnnotation { pub enum TypeAnnotation {
@ -12,7 +21,8 @@ pub enum TypeAnnotation {
// can only be CustomClassKind // can only be CustomClassKind
Virtual(Box<TypeAnnotation>), Virtual(Box<TypeAnnotation>),
TypeVar(Type), TypeVar(Type),
List(Box<TypeAnnotation>), /// A `Literal` allowing a subset of literals.
Literal(Vec<Constant>),
Tuple(Vec<TypeAnnotation>), Tuple(Vec<TypeAnnotation>),
} }
@ -22,52 +32,57 @@ impl TypeAnnotation {
match self { match self {
Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty), Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty),
CustomClass { id, params } => { CustomClass { id, params } => {
let class_name = match unifier.top_level { let class_name = if let Some(ref top) = unifier.top_level {
Some(ref top) => { if let TopLevelDef::Class { name, .. } = &*top.definitions.read()[id.0].read() {
if let TopLevelDef::Class { name, .. } = (*name).into()
&*top.definitions.read()[id.0].read() } else {
{ unreachable!()
(*name).into()
} else {
unreachable!()
}
} }
None => format!("class_def_{}", id.0), } else {
format!("class_def_{}", id.0)
}; };
format!( format!("{}{}", class_name, {
"{}{}", let param_list =
class_name, params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
{ if param_list.is_empty() {
let param_list = params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", "); String::new()
if param_list.is_empty() { } else {
"".into() format!("[{param_list}]")
} else {
format!("[{}]", param_list)
}
} }
) })
}
Literal(values) => {
format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", "))
} }
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)), Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
List(ty) => format!("list[{}]", ty.stringify(unifier)),
Tuple(types) => { Tuple(types) => {
format!("tuple[{}]", types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ")) format!(
"tuple[{}]",
types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ")
)
} }
} }
} }
} }
pub fn parse_ast_to_type_annotation_kinds<T>( /// Parses an AST expression `expr` into a [`TypeAnnotation`].
///
/// * `locked` - A [`HashMap`] containing the IDs of known definitions, mapped to a [`Vec`] of all
/// generic variables associated with the definition.
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
/// [`None`] when this function is invoked externally.
pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
resolver: &(dyn SymbolResolver + Send + Sync), resolver: &(dyn SymbolResolver + Send + Sync),
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
expr: &ast::Expr<T>, expr: &ast::Expr<T>,
// the key stores the type_var of this topleveldef::class, we only need this field here // the key stores the type_var of this topleveldef::class, we only need this field here
locked: HashMap<DefinitionId, Vec<Type>>, locked: HashMap<DefinitionId, Vec<Type>, S>,
) -> Result<TypeAnnotation, String> { ) -> Result<TypeAnnotation, HashSet<String>> {
let name_handle = |id: &StrRef, let name_handle = |id: &StrRef,
unifier: &mut Unifier, unifier: &mut Unifier,
locked: HashMap<DefinitionId, Vec<Type>>| { locked: HashMap<DefinitionId, Vec<Type>, S>| {
if id == &"int32".into() { if id == &"int32".into() {
Ok(TypeAnnotation::Primitive(primitives.int32)) Ok(TypeAnnotation::Primitive(primitives.int32))
} else if id == &"int64".into() { } else if id == &"int64".into() {
@ -83,7 +98,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} else if id == &"str".into() { } else if id == &"str".into() {
Ok(TypeAnnotation::Primitive(primitives.str)) Ok(TypeAnnotation::Primitive(primitives.str))
} else if id == &"Exception".into() { } else if id == &"Exception".into() {
Ok(TypeAnnotation::CustomClass { id: DefinitionId(7), params: Default::default() }) Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) { } else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
let type_vars = { let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read(); let def_read = top_level_defs[obj_id.0].try_read();
@ -91,10 +106,10 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
if let TopLevelDef::Class { type_vars, .. } = &*def_read { if let TopLevelDef::Class { type_vars, .. } = &*def_read {
type_vars.clone() type_vars.clone()
} else { } else {
return Err(format!( return Err(HashSet::from([format!(
"function cannot be used as a type (at {})", "function cannot be used as a type (at {})",
expr.location expr.location
)); )]));
} }
} else { } else {
locked.get(&obj_id).unwrap().clone() locked.get(&obj_id).unwrap().clone()
@ -102,23 +117,29 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
}; };
// check param number here // check param number here
if !type_vars.is_empty() { if !type_vars.is_empty() {
return Err(format!( return Err(HashSet::from([format!(
"expect {} type variable parameter but got 0 (at {})", "expect {} type variable parameter but got 0 (at {})",
type_vars.len(), type_vars.len(),
expr.location, expr.location,
)); )]));
} }
Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] }) Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] })
} else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { } else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) {
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() { if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).0; let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).ty;
unifier.unify(var, ty).unwrap(); unifier.unify(var, ty).unwrap();
Ok(TypeAnnotation::TypeVar(ty)) Ok(TypeAnnotation::TypeVar(ty))
} else { } else {
Err(format!("`{}` is not a valid type annotation (at {})", id, expr.location)) Err(HashSet::from([format!(
"`{}` is not a valid type annotation (at {})",
id, expr.location
)]))
} }
} else { } else {
Err(format!("`{}` is not a valid type annotation (at {})", id, expr.location)) Err(HashSet::from([format!(
"`{}` is not a valid type annotation (at {})",
id, expr.location
)]))
} }
}; };
@ -126,20 +147,22 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|id: &StrRef, |id: &StrRef,
slice: &ast::Expr<T>, slice: &ast::Expr<T>,
unifier: &mut Unifier, unifier: &mut Unifier,
mut locked: HashMap<DefinitionId, Vec<Type>>| { mut locked: HashMap<DefinitionId, Vec<Type>, S>| {
if vec!["virtual".into(), "Generic".into(), "list".into(), "tuple".into()].contains(id) if ["virtual".into(), "Generic".into(), "tuple".into(), "Option".into()].contains(id) {
{ return Err(HashSet::from([format!(
return Err(format!("keywords cannot be class name (at {})", expr.location)); "keywords cannot be class name (at {})",
expr.location
)]));
} }
let obj_id = resolver.get_identifier_def(*id)?; let obj_id = resolver.get_identifier_def(*id)?;
let type_vars = { let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read(); let def_read = top_level_defs[obj_id.0].try_read();
if let Some(def_read) = def_read { if let Some(def_read) = def_read {
if let TopLevelDef::Class { type_vars, .. } = &*def_read { let TopLevelDef::Class { type_vars, .. } = &*def_read else {
type_vars.clone()
} else {
unreachable!("must be class here") unreachable!("must be class here")
} };
type_vars.clone()
} else { } else {
locked.get(&obj_id).unwrap().clone() locked.get(&obj_id).unwrap().clone()
} }
@ -152,12 +175,12 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
vec![slice] vec![slice]
}; };
if type_vars.len() != params_ast.len() { if type_vars.len() != params_ast.len() {
return Err(format!( return Err(HashSet::from([format!(
"expect {} type parameters but got {} (at {})", "expect {} type parameters but got {} (at {})",
type_vars.len(), type_vars.len(),
params_ast.len(), params_ast.len(),
params_ast[0].location, params_ast[0].location,
)); )]));
} }
let result = params_ast let result = params_ast
.iter() .iter()
@ -181,15 +204,17 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
if no_type_var { if no_type_var {
result result
} else { } else {
return Err(format!( return Err(HashSet::from([
"application of type vars to generic class \ format!(
is not currently supported (at {})", "application of type vars to generic class is not currently supported (at {})",
params_ast[0].location params_ast[0].location
)); ),
]));
} }
}; };
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
}; };
match &expr.node { match &expr.node {
ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked), ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked),
// virtual // virtual
@ -212,23 +237,6 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
Ok(TypeAnnotation::Virtual(def.into())) Ok(TypeAnnotation::Virtual(def.into()))
} }
// list
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"list".into())
} =>
{
let def_ann = parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
slice.as_ref(),
locked,
)?;
Ok(TypeAnnotation::List(def_ann.into()))
}
// option // option
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
if { if {
@ -281,16 +289,70 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
Ok(TypeAnnotation::Tuple(type_annotations)) Ok(TypeAnnotation::Tuple(type_annotations))
} }
// Literal
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into())
} =>
{
let tup_elts = {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.as_slice()
} else {
std::slice::from_ref(slice.as_ref())
}
};
let type_annotations = tup_elts
.iter()
.map(|e| match &e.node {
ast::ExprKind::Constant { value, .. } => {
Ok(TypeAnnotation::Literal(vec![value.clone()]))
}
_ => parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
),
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.flat_map(|type_ann| match type_ann {
TypeAnnotation::Literal(values) => values,
_ => unreachable!(),
})
.collect_vec();
if type_annotations.len() == 1 {
Ok(TypeAnnotation::Literal(type_annotations))
} else {
Err(HashSet::from([format!(
"multiple literal bounds are currently unsupported (at {})",
value.location
)]))
}
}
// custom class // custom class
ast::ExprKind::Subscript { value, slice, .. } => { ast::ExprKind::Subscript { value, slice, .. } => {
if let ast::ExprKind::Name { id, .. } = &value.node { if let ast::ExprKind::Name { id, .. } = &value.node {
class_name_handle(id, slice, unifier, locked) class_name_handle(id, slice, unifier, locked)
} else { } else {
Err(format!("unsupported expression type for class name (at {})", value.location)) Err(HashSet::from([format!(
"unsupported expression type for class name (at {})",
value.location
)]))
} }
} }
_ => Err(format!("unsupported expression for type annotation (at {})", expr.location)), ast::ExprKind::Constant { value, .. } => Ok(TypeAnnotation::Literal(vec![value.clone()])),
_ => Err(HashSet::from([format!(
"unsupported expression for type annotation (at {})",
expr.location
)])),
} }
} }
@ -302,41 +364,78 @@ pub fn get_type_from_type_annotation_kinds(
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
ann: &TypeAnnotation, ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>> subst_list: &mut Option<Vec<Type>>,
) -> Result<Type, String> { ) -> Result<Type, HashSet<String>> {
match ann { match ann {
TypeAnnotation::CustomClass { id: obj_id, params } => { TypeAnnotation::CustomClass { id: obj_id, params } => {
let def_read = top_level_defs[obj_id.0].read(); let def_read = top_level_defs[obj_id.0].read();
let class_def: &TopLevelDef = def_read.deref(); let class_def: &TopLevelDef = &def_read;
if let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def { let TopLevelDef::Class { fields, methods, type_vars, .. } = class_def else {
if type_vars.len() != params.len() { unreachable!("should be class def here")
Err(format!( };
"unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
params.len()
))
} else {
let param_ty = params
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list
)
})
.collect::<Result<Vec<_>, _>>()?;
let subst = { if type_vars.len() != params.len() {
// check for compatible range return Err(HashSet::from([format!(
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check "unexpected number of type parameters: expected {} but got {}",
let mut result: HashMap<u32, Type> = HashMap::new(); type_vars.len(),
for (tvar, p) in type_vars.iter().zip(param_ty) { params.len()
if let TypeEnum::TVar { id, range, fields: None, name, loc } = )]));
unifier.get_ty(*tvar).as_ref() }
{
let param_ty = params
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
})
.collect::<Result<Vec<_>, _>>()?;
let ty = if let Some(prim_def) = PrimDef::iter().find(|prim| prim.id() == *obj_id) {
// Primitive TopLevelDefs do not contain all fields that are present in their Type
// counterparts, so directly perform subst on the Type instead.
let PrimDefDetails::PrimClass { get_ty_fn, .. } = prim_def.details() else {
unreachable!()
};
let base_ty = get_ty_fn(primitives);
let params =
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty_immutable(base_ty) {
params.clone()
} else {
unreachable!()
};
unifier
.subst(
get_ty_fn(primitives),
&params
.iter()
.zip(param_ty)
.map(|(obj_tv, param)| (*obj_tv.0, param))
.collect(),
)
.unwrap_or(base_ty)
} else {
let subst = {
// check for compatible range
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = { let ok: bool = {
// create a temp type var and unify to check compatibility // create a temp type var and unify to check compatibility
p == *tvar || { p == *tvar || {
@ -345,85 +444,119 @@ pub fn get_type_from_type_annotation_kinds(
*name, *name,
*loc, *loc,
); );
unifier.unify(temp.0, p).is_ok() unifier.unify(temp.ty, p).is_ok()
} }
}; };
if ok { if ok {
result.insert(*id, p); result.insert(*id, p);
} else { } else {
return Err(format!( return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}", "cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify( unifier.internal_stringify(
p, p,
&mut |id| format!("class{}", id), &mut |id| format!("class{id}"),
&mut |id| format!("tvar{}", id), &mut |id| format!("typevar{id}"),
&mut None &mut None
), ),
*id *id
)); )]));
} }
} else {
unreachable!("must be generic type var")
} }
TypeEnum::TVar {
id, range, name, loc, is_const_generic: true, ..
} => {
let ty = range[0];
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp =
unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.ty, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
_ => unreachable!("must be generic type var"),
} }
result
};
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
subst_list.as_mut().map(|wl| wl.push(ty));
} }
Ok(ty) result
};
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
// methods are immutable
(*name, (subst_ty, false))
})
.collect::<HashMap<_, _>>();
tobj_fields.extend(fields.iter().map(|(name, ty, mutability)| {
let subst_ty = unifier.subst(*ty, &subst).unwrap_or(*ty);
(*name, (subst_ty, *mutability))
}));
let need_subst = !subst.is_empty();
let ty = unifier.add_ty(TypeEnum::TObj {
obj_id: *obj_id,
fields: tobj_fields,
params: subst,
});
if need_subst {
if let Some(wl) = subst_list.as_mut() {
wl.push(ty);
}
} }
} else {
unreachable!("should be class def here") ty
} };
Ok(ty)
} }
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Literal(values) => {
let values = values
.iter()
.map(SymbolValue::from_constant_inferred)
.collect::<Result<Vec<_>, _>>()
.map_err(|err| HashSet::from([err]))?;
let var = unifier.get_fresh_literal(values, None);
Ok(var)
}
TypeAnnotation::Virtual(ty) => { TypeAnnotation::Virtual(ty) => {
let ty = get_type_from_type_annotation_kinds( let ty = get_type_from_type_annotation_kinds(
top_level_defs, top_level_defs,
unifier, unifier,
primitives, primitives,
ty.as_ref(), ty.as_ref(),
subst_list subst_list,
)?; )?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} }
TypeAnnotation::List(ty) => {
let ty = get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
ty.as_ref(),
subst_list
)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
}
TypeAnnotation::Tuple(tys) => { TypeAnnotation::Tuple(tys) => {
let tys = tys let tys = tys
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds(top_level_defs, unifier, primitives, x, subst_list) get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
primitives,
x,
subst_list,
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys })) Ok(unifier.add_ty(TypeEnum::TTuple { ty: tys, is_vararg_ctx: false }))
} }
} }
} }
@ -436,10 +569,11 @@ pub fn get_type_from_type_annotation_kinds(
/// the type of `self` should be similar to `A[T, V]`, where `T`, `V` /// the type of `self` should be similar to `A[T, V]`, where `T`, `V`
/// considered to be type variables associated with the class \ /// considered to be type variables associated with the class \
/// \ /// \
/// But note that here we do not make a duplication of `T`, `V`, we direclty /// But note that here we do not make a duplication of `T`, `V`, we directly
/// use them as they are in the TopLevelDef::Class since those in the /// use them as they are in the [`TopLevelDef::Class`] since those in the
/// TopLevelDef::Class.type_vars will be substitute later when seeing applications/instantiations /// `TopLevelDef::Class.type_vars` will be substitute later when seeing applications/instantiations
/// the Type of their fields and methods will also be subst when application/instantiation /// the Type of their fields and methods will also be subst when application/instantiation
#[must_use]
pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation { pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) -> TypeAnnotation {
TypeAnnotation::CustomClass { TypeAnnotation::CustomClass {
id: object_id, id: object_id,
@ -450,27 +584,25 @@ pub fn make_self_type_annotation(type_vars: &[Type], object_id: DefinitionId) ->
/// get all the occurences of type vars contained in a type annotation /// get all the occurences of type vars contained in a type annotation
/// e.g. `A[int, B[T], V, virtual[C[G]]]` => [T, V, G] /// e.g. `A[int, B[T], V, virtual[C[G]]]` => [T, V, G]
/// this function will not make a duplicate of type var /// this function will not make a duplicate of type var
#[must_use]
pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<TypeAnnotation> { pub fn get_type_var_contained_in_type_annotation(ann: &TypeAnnotation) -> Vec<TypeAnnotation> {
let mut result: Vec<TypeAnnotation> = Vec::new(); let mut result: Vec<TypeAnnotation> = Vec::new();
match ann { match ann {
TypeAnnotation::TypeVar(..) => result.push(ann.clone()), TypeAnnotation::TypeVar(..) => result.push(ann.clone()),
TypeAnnotation::Virtual(ann) => { TypeAnnotation::Virtual(ann) => {
result.extend(get_type_var_contained_in_type_annotation(ann.as_ref())) result.extend(get_type_var_contained_in_type_annotation(ann.as_ref()));
} }
TypeAnnotation::CustomClass { params, .. } => { TypeAnnotation::CustomClass { params, .. } => {
for p in params { for p in params {
result.extend(get_type_var_contained_in_type_annotation(p)); result.extend(get_type_var_contained_in_type_annotation(p));
} }
} }
TypeAnnotation::List(ann) => {
result.extend(get_type_var_contained_in_type_annotation(ann.as_ref()))
}
TypeAnnotation::Tuple(anns) => { TypeAnnotation::Tuple(anns) => {
for a in anns { for a in anns {
result.extend(get_type_var_contained_in_type_annotation(a)); result.extend(get_type_var_contained_in_type_annotation(a));
} }
} }
TypeAnnotation::Primitive(..) => {} TypeAnnotation::Primitive(..) | TypeAnnotation::Literal { .. } => {}
} }
result result
} }
@ -485,21 +617,20 @@ pub fn check_overload_type_annotation_compatible(
(TypeAnnotation::Primitive(a), TypeAnnotation::Primitive(b)) => a == b, (TypeAnnotation::Primitive(a), TypeAnnotation::Primitive(b)) => a == b,
(TypeAnnotation::TypeVar(a), TypeAnnotation::TypeVar(b)) => { (TypeAnnotation::TypeVar(a), TypeAnnotation::TypeVar(b)) => {
let a = unifier.get_ty(*a); let a = unifier.get_ty(*a);
let a = a.deref(); let a = &*a;
let b = unifier.get_ty(*b); let b = unifier.get_ty(*b);
let b = b.deref(); let b = &*b;
if let ( let (
TypeEnum::TVar { id: a, fields: None, .. }, TypeEnum::TVar { id: a, fields: None, .. },
TypeEnum::TVar { id: b, fields: None, .. }, TypeEnum::TVar { id: b, fields: None, .. },
) = (a, b) ) = (a, b)
{ else {
a == b
} else {
unreachable!("must be type var") unreachable!("must be type var")
} };
a == b
} }
(TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b)) (TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b)) => {
| (TypeAnnotation::List(a), TypeAnnotation::List(b)) => {
check_overload_type_annotation_compatible(a.as_ref(), b.as_ref(), unifier) check_overload_type_annotation_compatible(a.as_ref(), b.as_ref(), unifier)
} }

View File

@ -1,515 +0,0 @@
use slab::Slab;
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use nac3parser::ast::{Location, StrRef};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LifetimeKind {
Static,
NonLocal,
Unknown,
PreciseLocal,
ImpreciseLocal,
}
impl std::ops::BitAnd for LifetimeKind {
type Output = Self;
fn bitand(self, rhs: Self) -> Self::Output {
use LifetimeKind::*;
match (self, rhs) {
(x, y) if x == y => x,
(PreciseLocal, ImpreciseLocal) | (ImpreciseLocal, PreciseLocal) => ImpreciseLocal,
(Static, NonLocal) | (NonLocal, Static) => NonLocal,
_ => Unknown,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LifetimeId(usize);
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BasicBlockId(usize);
#[derive(Debug, Clone)]
pub enum LifetimeIR {
VarAssign { var: StrRef, lifetime: LifetimeId },
VarAccess { var: StrRef },
FieldAssign { obj: LifetimeId, field: StrRef, new: LifetimeId, is_init: bool },
FieldAccess { obj: LifetimeId, field: StrRef },
CreateLifetime { kind: LifetimeKind },
PassedToFunc { param_lifetimes: Vec<LifetimeId> },
UnifyLifetimes { lifetimes: Vec<LifetimeId> },
Branch { targets: Vec<BasicBlockId> },
Return { val: Option<LifetimeId> },
}
pub struct LifetimeIRBuilder {
irs: Vec<Option<(LifetimeIR, Location)>>,
basic_blocks: Vec<Vec<usize>>,
current_block: BasicBlockId,
}
impl LifetimeIRBuilder {
pub fn new() -> Self {
LifetimeIRBuilder {
irs: vec![None],
basic_blocks: vec![vec![]],
current_block: BasicBlockId(0),
}
}
pub fn print_ir(&self) -> String {
let mut lines = vec![];
for (i, bb) in self.basic_blocks.iter().enumerate() {
if bb.is_empty() {
continue;
}
lines.push(format!("{}:", i));
for ir in bb.iter() {
if let Some((inst, loc)) = &self.irs[*ir] {
lines.push(format!(" {}: {:?} ({})", *ir, inst, loc));
}
}
}
lines.join("\n")
}
pub fn append_ir(&mut self, inst: LifetimeIR, loc: Location) -> LifetimeId {
let id = self.irs.len();
self.irs.push(Some((inst, loc)));
self.basic_blocks[self.current_block.0].push(id);
LifetimeId(id)
}
pub fn append_block(&mut self) -> BasicBlockId {
let id = self.basic_blocks.len();
self.basic_blocks.push(vec![]);
BasicBlockId(id)
}
pub fn get_current_block(&self) -> BasicBlockId {
self.current_block
}
pub fn position_at_end(&mut self, id: BasicBlockId) {
self.current_block = id;
}
pub fn is_terminated(&self, id: BasicBlockId) -> bool {
let bb = &self.basic_blocks[id.0];
if bb.is_empty() {
false
} else {
matches!(
self.irs[*bb.last().unwrap()],
Some((LifetimeIR::Return { .. }, _)) | Some((LifetimeIR::Branch { .. }, _))
)
}
}
pub fn remove_empty_bb(&mut self) {
let mut destination_mapping = HashMap::new();
let basic_blocks = &mut self.basic_blocks;
let irs = &mut self.irs;
for (i, bb) in basic_blocks.iter_mut().enumerate() {
bb.retain(|&id| irs[id].is_some());
if bb.len() == 1 {
let id = bb.pop().unwrap();
let ir = irs[id].take().unwrap();
match ir.0 {
LifetimeIR::Branch { targets } => {
destination_mapping.insert(i, targets);
}
_ => (),
}
}
}
let mut buffer = HashSet::new();
for bb in basic_blocks.iter_mut() {
if bb.is_empty() {
continue;
}
if let LifetimeIR::Branch { targets } =
&mut irs[*bb.last().unwrap()].as_mut().unwrap().0
{
buffer.clear();
let mut updated = false;
for target in targets.iter() {
if let Some(dest) = destination_mapping.get(&target.0) {
buffer.extend(dest.iter().cloned());
updated = true;
} else {
buffer.insert(*target);
}
}
if updated {
targets.clear();
targets.extend(buffer.iter().cloned());
}
}
}
}
pub fn analyze(&self) -> Result<(), String> {
let mut analyzers = HashMap::new();
analyzers.insert(0, (0, true, LifetimeAnalyzer::new()));
let mut worklist = vec![0];
while let Some(bb) = worklist.pop() {
let (counter, updated, analyzer) = analyzers.get_mut(&bb).unwrap();
*counter += 1;
if *counter > 100 {
return Err(format!("infinite loop detected at basic block {}", bb));
}
*updated = false;
let mut analyzer = analyzer.clone();
let block = &self.basic_blocks[bb];
let ir_iter = block.iter().filter_map(|&id| {
self.irs[id].as_ref().map(|(ir, loc)| (LifetimeId(id), ir, *loc))
});
if let Some(branch) = analyzer.analyze_basic_block(ir_iter)? {
for &target in branch.iter() {
if let Some((_, updated, successor)) = analyzers.get_mut(&target.0) {
if successor.merge(&analyzer) && !*updated {
// changed
worklist.push(target.0);
*updated = true;
}
} else {
analyzers.insert(target.0, (0, true, analyzer.clone()));
worklist.push(target.0);
}
}
}
}
Ok(())
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct LifetimeStore {
kind: LifetimeKind,
fields: HashMap<StrRef, LifetimeId>,
lifetimes: HashSet<LifetimeId>,
}
#[derive(Debug, Clone)]
pub struct LifetimeAnalyzer<'a> {
lifetime_to_id: HashMap<LifetimeId, usize>,
lifetime_stores: Slab<Cow<'a, LifetimeStore>>,
variable_assignment: HashMap<StrRef, LifetimeId>,
}
impl<'a> LifetimeAnalyzer<'a> {
pub fn new() -> Self {
let mut zelf = LifetimeAnalyzer {
lifetime_to_id: HashMap::new(),
lifetime_stores: Default::default(),
variable_assignment: HashMap::new(),
};
zelf.add_lifetime(LifetimeId(0), LifetimeKind::Unknown);
zelf
}
pub fn merge(&mut self, other: &LifetimeAnalyzer) -> bool {
let mut to_be_merged = other.lifetime_to_id.keys().cloned().collect::<Vec<_>>();
let mut updated = false;
let mut lifetime_merge_list = vec![];
for (&var_name, &lifetime) in other.variable_assignment.iter() {
if let Some(&our_lifetime) = self.variable_assignment.get(&var_name) {
if our_lifetime != lifetime {
lifetime_merge_list.push((our_lifetime, lifetime));
}
} else {
self.variable_assignment.insert(var_name, lifetime);
updated = true;
}
}
while let Some(lifetime) = to_be_merged.pop() {
let other_store_id = *other.lifetime_to_id.get(&lifetime).unwrap();
if let Some(&self_store_id) = self.lifetime_to_id.get(&lifetime) {
let self_store = self.lifetime_stores.get_mut(self_store_id).unwrap();
let other_store = other.lifetime_stores.get(other_store_id).unwrap();
let self_store = self_store.to_mut();
// merge them
for (&field, &other_lifetime) in other_store.fields.iter() {
if let Some(&self_lifetime) = self_store.fields.get(&field) {
if self_lifetime != other_lifetime {
lifetime_merge_list.push((self_lifetime, other_lifetime));
}
} else {
self_store.fields.insert(field, other_lifetime);
updated = true;
}
}
let zelf_lifetimes = &mut self_store.lifetimes;
for &other_lifetime in other_store.lifetimes.iter() {
if zelf_lifetimes.insert(other_lifetime) {
lifetime_merge_list.push((lifetime, other_lifetime));
}
}
let result_kind = self_store.kind & other_store.kind;
if self_store.kind != result_kind {
self_store.kind = result_kind;
}
} else {
let store = other.lifetime_stores.get(other_store_id).unwrap().as_ref().clone();
let store = self.lifetime_stores.insert(Cow::Owned(store));
self.lifetime_to_id.insert(lifetime, store);
updated = true;
}
}
for (a, b) in lifetime_merge_list.into_iter() {
self.unify(a, b);
}
updated
}
pub fn add_lifetime(&mut self, lifetime: LifetimeId, kind: LifetimeKind) {
let id = self.lifetime_stores.insert(Cow::Owned(LifetimeStore {
kind,
fields: HashMap::new(),
lifetimes: [lifetime].iter().cloned().collect(),
}));
let old_store_id = self.lifetime_to_id.insert(lifetime, id);
if let Some(old_store_id) = old_store_id {
let old_lifetime_store = self.lifetime_stores.get_mut(old_store_id).unwrap().to_mut();
old_lifetime_store.lifetimes.remove(&lifetime);
if old_lifetime_store.lifetimes.is_empty() {
self.lifetime_stores.remove(old_store_id);
}
}
}
pub fn set_lifetime(&mut self, lifetime: LifetimeId, to: LifetimeId) {
let id = *self.lifetime_to_id.get(&to).unwrap();
let store = self.lifetime_stores.get_mut(id).unwrap();
store.to_mut().lifetimes.insert(lifetime);
let old_store_id = self.lifetime_to_id.insert(lifetime, id);
if let Some(old_store_id) = old_store_id {
let old_lifetime_store = self.lifetime_stores.get_mut(old_store_id).unwrap().to_mut();
old_lifetime_store.lifetimes.remove(&lifetime);
if old_lifetime_store.lifetimes.is_empty() {
self.lifetime_stores.remove(old_store_id);
}
}
}
fn unify(&mut self, lhs: LifetimeId, rhs: LifetimeId) {
use LifetimeKind::{ImpreciseLocal, PreciseLocal};
let lhs_id = *self.lifetime_to_id.get(&lhs).unwrap();
let rhs_id = *self.lifetime_to_id.get(&rhs).unwrap();
if lhs_id == rhs_id {
return;
}
let lhs_store = self.lifetime_stores.get(lhs_id).unwrap();
let rhs_store = self.lifetime_stores.get(rhs_id).unwrap();
let all_lifetimes: HashSet<_> =
lhs_store.lifetimes.union(&rhs_store.lifetimes).cloned().collect();
let result_kind = lhs_store.kind & rhs_store.kind;
let fields = if matches!(result_kind, PreciseLocal | ImpreciseLocal) {
let mut need_union = vec![];
let mut fields = lhs_store.fields.clone();
for (k, v) in rhs_store.fields.iter() {
if let Some(old) = fields.insert(*k, *v) {
need_union.push((old, *v));
}
}
drop(lhs_store);
drop(rhs_store);
for (lhs, rhs) in need_union {
self.unify(lhs, rhs);
}
fields
} else {
Default::default()
};
// unify them, slow
for lifetime in all_lifetimes.iter() {
self.lifetime_to_id.insert(*lifetime, lhs_id);
}
*self.lifetime_stores.get_mut(lhs_id).unwrap() =
Cow::Owned(LifetimeStore { kind: result_kind, fields, lifetimes: all_lifetimes });
self.lifetime_stores.remove(rhs_id);
}
fn get_field_lifetime(&mut self, obj: LifetimeId, field: StrRef) -> LifetimeId {
use LifetimeKind::*;
let id = *self.lifetime_to_id.get(&obj).unwrap();
let store = self.lifetime_stores.get(id).unwrap();
if matches!(store.kind, PreciseLocal | ImpreciseLocal) {
if let Some(&lifetime) = store.fields.get(&field) {
let field_lifetime_kind = self.get_lifetime_kind(lifetime);
if field_lifetime_kind == PreciseLocal
&& (store.kind == ImpreciseLocal || field == "$elem".into())
{
let id = *self.lifetime_to_id.get(&lifetime).unwrap();
self.lifetime_stores.get_mut(id).unwrap().to_mut().kind = ImpreciseLocal;
}
lifetime
} else {
LifetimeId(0)
}
} else {
obj
}
}
fn set_field_lifetime(
&mut self,
obj: LifetimeId,
field: StrRef,
field_lifetime: LifetimeId,
is_init: bool,
) -> Result<(), String> {
use LifetimeKind::*;
let obj_id = *self.lifetime_to_id.get(&obj).unwrap();
let field_id = *self.lifetime_to_id.get(&field_lifetime).unwrap();
let field_lifetime_kind = self.lifetime_stores.get(field_id).unwrap().kind;
let obj_store = self.lifetime_stores.get_mut(obj_id).unwrap();
if !matches!(
(obj_store.kind, field_lifetime_kind),
(PreciseLocal, _) | (ImpreciseLocal, _) | (_, Static)
) {
return Err("field lifetime error".into());
}
match obj_store.kind {
// $elem means list elements
PreciseLocal if field != "$elem".into() => {
// strong update
obj_store.to_mut().fields.insert(field, field_lifetime);
}
PreciseLocal | ImpreciseLocal => {
// weak update
let old_lifetime = obj_store.to_mut().fields.get(&field).copied();
if let Some(old_lifetime) = old_lifetime {
self.unify(old_lifetime, field_lifetime);
} else {
obj_store.to_mut().fields.insert(field, field_lifetime);
if !is_init {
// unify with unknown lifetime
self.unify(LifetimeId(0), field_lifetime);
}
if field == "$elem".into() {
let field_lifetime_id = *self.lifetime_to_id.get(&field_lifetime).unwrap();
let field_lifetime = self.lifetime_stores.get_mut(field_lifetime_id).unwrap();
if field_lifetime.kind == PreciseLocal {
field_lifetime.to_mut().kind = ImpreciseLocal;
}
}
}
}
_ => (),
}
Ok(())
}
fn get_lifetime_kind(&self, lifetime: LifetimeId) -> LifetimeKind {
self.lifetime_stores.get(*self.lifetime_to_id.get(&lifetime).unwrap()).unwrap().kind
}
fn pass_function_params(&mut self, lifetimes: &[LifetimeId]) {
use LifetimeKind::*;
let mut visited = HashSet::new();
let mut worklist = vec![];
fn add_fields_to_worklist(
visited: &mut HashSet<LifetimeId>,
worklist: &mut Vec<(LifetimeId, bool)>,
fields: &HashMap<StrRef, LifetimeId>,
) {
for (&name, &field) in fields.iter() {
if visited.insert(field) {
// not visited previously
let name = name.to_string();
let mutable = !(name.starts_with("$elem") && name.len() != "$elem".len());
worklist.push((field, mutable));
}
}
}
for lifetime in lifetimes.iter() {
let lifetime =
self.lifetime_stores.get_mut(*self.lifetime_to_id.get(lifetime).unwrap()).unwrap();
add_fields_to_worklist(&mut visited, &mut worklist, &lifetime.fields);
}
while let Some((item, mutable)) = worklist.pop() {
let lifetime =
self.lifetime_stores.get_mut(*self.lifetime_to_id.get(&item).unwrap()).unwrap();
if matches!(lifetime.kind, Unknown | Static) {
continue;
}
add_fields_to_worklist(&mut visited, &mut worklist, &lifetime.fields);
if mutable {
// we may assign values with static lifetime to function params
lifetime.to_mut().kind = lifetime.kind & Static;
}
}
}
pub fn analyze_basic_block<'b, I: Iterator<Item = (LifetimeId, &'b LifetimeIR, Location)>>(
&mut self,
instructions: I,
) -> Result<Option<&'b [BasicBlockId]>, String> {
use LifetimeIR::*;
for (id, inst, loc) in instructions {
match inst {
VarAssign { var, lifetime } => {
self.variable_assignment.insert(*var, *lifetime);
}
VarAccess { var } => {
let lifetime = self.variable_assignment.get(var).cloned();
if let Some(lifetime) = lifetime {
self.set_lifetime(id, lifetime);
} else {
// should be static lifetime
self.add_lifetime(id, LifetimeKind::Static)
}
}
FieldAssign { obj, field, new, is_init } => {
self.set_field_lifetime(*obj, *field, *new, *is_init)
.map_err(|e| format!("{} in {}", e, loc))?;
}
FieldAccess { obj, field } => {
let lifetime = self.get_field_lifetime(*obj, *field);
self.set_lifetime(id, lifetime);
}
CreateLifetime { kind } => {
if *kind == LifetimeKind::Unknown {
self.set_lifetime(id, LifetimeId(0));
} else {
self.add_lifetime(id, *kind);
}
}
PassedToFunc { param_lifetimes } => {
self.pass_function_params(param_lifetimes);
}
UnifyLifetimes { lifetimes } => {
assert!(!lifetimes.is_empty());
let lhs = lifetimes[0];
for rhs in lifetimes[1..].iter() {
self.unify(lhs, *rhs);
}
self.set_lifetime(id, lhs);
}
Return { val } => {
if let Some(val) = val {
let kind = self.get_lifetime_kind(*val);
if !matches!(kind, LifetimeKind::Static | LifetimeKind::NonLocal) {
return Err(format!("return value lifetime error in {}", loc));
}
}
return Ok(None);
}
Branch { targets } => return Ok(Some(targets)),
}
}
Ok(None)
}
}

View File

@ -1,580 +0,0 @@
use std::sync::Arc;
use itertools::chain;
use nac3parser::ast::{Comprehension, Constant, Expr, ExprKind, Location, Stmt, StmtKind, StrRef};
use lifetime::{BasicBlockId, LifetimeIR, LifetimeIRBuilder, LifetimeId, LifetimeKind};
use crate::{
symbol_resolver::SymbolResolver,
toplevel::{TopLevelContext, TopLevelDef},
};
use super::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier},
};
#[cfg(test)]
mod test;
mod lifetime;
pub struct EscapeAnalyzer<'a> {
builder: LifetimeIRBuilder,
loop_head: Option<BasicBlockId>,
loop_tail: Option<BasicBlockId>,
unifier: &'a mut Unifier,
primitive_store: &'a PrimitiveStore,
resolver: Arc<dyn SymbolResolver + Send + Sync>,
top_level: &'a TopLevelContext,
}
impl<'a> EscapeAnalyzer<'a> {
pub fn new(
unifier: &'a mut Unifier,
primitive_store: &'a PrimitiveStore,
resolver: Arc<dyn SymbolResolver + Send + Sync>,
top_level: &'a TopLevelContext,
) -> Self {
Self {
builder: LifetimeIRBuilder::new(),
loop_head: None,
loop_tail: None,
primitive_store,
unifier,
resolver,
top_level,
}
}
pub fn check_function_lifetime(
unifier: &'a mut Unifier,
primitive_store: &'a PrimitiveStore,
resolver: Arc<dyn SymbolResolver + Send + Sync>,
top_level: &'a TopLevelContext,
args: &[(StrRef, Type)],
body: &[Stmt<Option<Type>>],
loc: Location,
) -> Result<(), String> {
use LifetimeIR::{CreateLifetime, VarAssign};
let mut zelf = Self::new(unifier, primitive_store, resolver, top_level);
let nonlocal_lifetime =
zelf.builder.append_ir(CreateLifetime { kind: LifetimeKind::NonLocal }, loc);
for (name, ty) in args.iter().copied() {
if zelf.need_alloca(ty) {
zelf.builder.append_ir(VarAssign { var: name, lifetime: nonlocal_lifetime }, loc);
}
}
zelf.handle_statements(body)?;
zelf.builder.remove_empty_bb();
zelf.builder.analyze().map_err(|e| {
format!("{}\nIR: {}", e, zelf.builder.print_ir())
})
}
fn need_alloca(&mut self, ty: Type) -> bool {
!(self.unifier.unioned(ty, self.primitive_store.int32)
|| self.unifier.unioned(ty, self.primitive_store.int64)
|| self.unifier.unioned(ty, self.primitive_store.uint32)
|| self.unifier.unioned(ty, self.primitive_store.uint64)
|| self.unifier.unioned(ty, self.primitive_store.float)
|| self.unifier.unioned(ty, self.primitive_store.bool)
|| self.unifier.unioned(ty, self.primitive_store.none)
|| self.unifier.unioned(ty, self.primitive_store.range))
}
fn is_terminated(&self) -> bool {
self.builder.is_terminated(self.builder.get_current_block())
}
fn handle_unknown_function_call<P: std::borrow::Borrow<Expr<Option<Type>>>>(
&mut self,
params: &[P],
ret_need_alloca: bool,
loc: Location,
) -> Result<Option<LifetimeId>, String> {
let param_lifetimes = params
.iter()
.filter_map(|p| self.handle_expr(p.borrow()).transpose())
.collect::<Result<Vec<_>, _>>()?;
self.builder.append_ir(LifetimeIR::PassedToFunc { param_lifetimes }, loc);
if ret_need_alloca {
Ok(Some(
self.builder
.append_ir(LifetimeIR::CreateLifetime { kind: LifetimeKind::Unknown }, loc),
))
} else {
Ok(None)
}
}
fn handle_expr(&mut self, expr: &Expr<Option<Type>>) -> Result<Option<LifetimeId>, String> {
use LifetimeIR::*;
use LifetimeKind::*;
let need_alloca = self.need_alloca(expr.custom.unwrap());
let loc = expr.location;
Ok(match &expr.node {
ExprKind::Name { id, .. } => {
if need_alloca {
Some(self.builder.append_ir(VarAccess { var: *id }, loc))
} else {
None
}
}
ExprKind::Attribute { value, attr, .. } => {
if need_alloca {
let val = self.handle_expr(value)?.unwrap();
Some(self.builder.append_ir(FieldAccess { obj: val, field: *attr }, loc))
} else {
self.handle_expr(value)?;
None
}
}
ExprKind::Constant { .. } => {
if need_alloca {
Some(self.builder.append_ir(CreateLifetime { kind: Static }, loc))
} else {
None
}
}
ExprKind::List { elts, .. } => {
let elems =
elts.iter().map(|e| self.handle_expr(e)).collect::<Result<Vec<_>, _>>()?;
let list_lifetime =
self.builder.append_ir(CreateLifetime { kind: PreciseLocal }, loc);
if !elems.is_empty() {
if elems[0].is_some() {
let elems = elems.into_iter().map(|e| e.unwrap()).collect::<Vec<_>>();
let elem_lifetime =
self.builder.append_ir(UnifyLifetimes { lifetimes: elems }, loc);
self.builder.append_ir(
FieldAssign {
obj: list_lifetime,
field: "$elem".into(),
new: elem_lifetime,
is_init: true,
},
loc,
);
}
} else {
let elem_lifetime =
self.builder.append_ir(CreateLifetime { kind: PreciseLocal }, loc);
self.builder.append_ir(
FieldAssign {
obj: list_lifetime,
field: "$elem".into(),
new: elem_lifetime,
is_init: true,
},
loc,
);
}
Some(list_lifetime)
}
ExprKind::Tuple { elts, .. } => {
let elems =
elts.iter().map(|e| self.handle_expr(e)).collect::<Result<Vec<_>, _>>()?;
let tuple_lifetime =
self.builder.append_ir(CreateLifetime { kind: PreciseLocal }, loc);
for (i, lifetime) in elems.into_iter().enumerate() {
if let Some(lifetime) = lifetime {
self.builder.append_ir(
FieldAssign {
obj: tuple_lifetime,
field: format!("$elem{}", i).into(),
new: lifetime,
is_init: true,
},
loc,
);
}
}
Some(tuple_lifetime)
}
ExprKind::Subscript { value, slice, .. } => {
let value_lifetime = self.handle_expr(value)?.unwrap();
match &slice.node {
ExprKind::Slice { lower, upper, step } => {
for expr in [lower, upper, step].iter().filter_map(|x| x.as_ref()) {
self.handle_expr(expr)?;
}
let slice_lifetime =
self.builder.append_ir(CreateLifetime { kind: PreciseLocal }, loc);
let slice_elem = self.builder.append_ir(
FieldAccess { obj: value_lifetime, field: "$elem".into() },
loc,
);
self.builder.append_ir(
FieldAssign {
obj: slice_lifetime,
field: "$elem".into(),
new: slice_elem,
is_init: true
},
loc,
);
Some(slice_lifetime)
}
ExprKind::Constant { value: Constant::Int(v), .. }
if matches!(
&*self.unifier.get_ty(value.custom.unwrap()),
TypeEnum::TTuple { .. }
) =>
{
Some(self.builder.append_ir(
FieldAccess {
obj: value_lifetime,
field: format!("$elem{}", v).into(),
},
loc,
))
}
_ => {
self.handle_expr(slice)?;
if need_alloca {
Some(self.builder.append_ir(
FieldAccess { obj: value_lifetime, field: "$elem".into() },
loc,
))
} else {
None
}
}
}
}
ExprKind::Call { func, args, keywords } => {
let mut lifetimes = vec![];
for arg in chain!(args.iter(), keywords.iter().map(|k| k.node.value.as_ref())) {
if let Some(lifetime) = self.handle_expr(arg)? {
lifetimes.push(lifetime);
}
}
match &func.node {
ExprKind::Name { id, .. } => {
if !lifetimes.is_empty() {
self.builder.append_ir(PassedToFunc { param_lifetimes: lifetimes }, loc);
}
if need_alloca {
let id = self
.resolver
.get_identifier_def(*id)
.map_err(|e| format!("{} (at {})", e, func.location))?;
if let TopLevelDef::Class { .. } =
&*self.top_level.definitions.read()[id.0].read()
{
Some(
self.builder
.append_ir(CreateLifetime { kind: PreciseLocal }, loc),
)
} else {
Some(self.builder.append_ir(CreateLifetime { kind: Unknown }, loc))
}
} else {
None
}
}
ExprKind::Attribute { value, .. } => {
let obj_lifetime = self.handle_expr(value)?.unwrap();
lifetimes.push(obj_lifetime);
self.builder.append_ir(PassedToFunc { param_lifetimes: lifetimes }, loc);
if need_alloca {
Some(self.builder.append_ir(CreateLifetime { kind: Unknown }, loc))
} else {
None
}
}
_ => unimplemented!(),
}
}
ExprKind::BinOp { left, right, .. } => self.handle_unknown_function_call(
&[left.as_ref(), right.as_ref()],
need_alloca,
loc,
)?,
ExprKind::BoolOp { values, .. } => {
self.handle_unknown_function_call(&values, need_alloca, loc)?
}
ExprKind::UnaryOp { operand, .. } => {
self.handle_unknown_function_call(&[operand.as_ref()], need_alloca, loc)?
}
ExprKind::Compare { left, comparators, .. } => {
self.handle_unknown_function_call(&[left.as_ref()], false, loc)?;
self.handle_unknown_function_call(&comparators, need_alloca, loc)?
}
ExprKind::IfExp { test, body, orelse } => {
self.handle_expr(test)?;
let body_bb = self.builder.append_block();
let else_bb = self.builder.append_block();
let tail_bb = self.builder.append_block();
self.builder.append_ir(Branch { targets: vec![body_bb, else_bb] }, test.location);
self.builder.position_at_end(body_bb);
let body_lifetime = self.handle_expr(body)?;
self.builder.append_ir(Branch { targets: vec![tail_bb] }, body.location);
self.builder.position_at_end(else_bb);
let else_lifetime = self.handle_expr(body)?;
self.builder.append_ir(Branch { targets: vec![tail_bb] }, orelse.location);
self.builder.position_at_end(tail_bb);
if let (Some(body_lifetime), Some(else_lifetime)) = (body_lifetime, else_lifetime) {
Some(self.builder.append_ir(
UnifyLifetimes { lifetimes: vec![body_lifetime, else_lifetime] },
loc,
))
} else {
None
}
}
ExprKind::ListComp { elt, generators } => {
let Comprehension { target, iter, ifs, .. } = &generators[0];
let list_lifetime =
self.builder.append_ir(CreateLifetime { kind: PreciseLocal }, loc);
let iter_elem_lifetime = self.handle_expr(iter)?.map(|obj| {
self.builder
.append_ir(FieldAccess { obj, field: "$elem".into() }, iter.location)
});
let loop_body = self.builder.append_block();
let loop_tail = self.builder.append_block();
self.builder.append_ir(Branch { targets: vec![loop_body] }, loc);
self.builder.position_at_end(loop_body);
self.handle_assignment(target, iter_elem_lifetime)?;
for ifexpr in ifs.iter() {
self.handle_expr(ifexpr)?;
}
let elem_lifetime = self.handle_expr(elt)?;
if let Some(elem_lifetime) = elem_lifetime {
self.builder.append_ir(
FieldAssign {
obj: list_lifetime,
field: "$elem".into(),
new: elem_lifetime,
is_init: true
},
elt.location,
);
}
self.builder.append_ir(Branch { targets: vec![loop_body, loop_tail] }, loc);
self.builder.position_at_end(loop_tail);
Some(list_lifetime)
}
_ => unimplemented!(),
})
}
fn handle_assignment(
&mut self,
lhs: &Expr<Option<Type>>,
rhs_lifetime: Option<LifetimeId>,
) -> Result<(), String> {
use LifetimeIR::*;
match &lhs.node {
ExprKind::Attribute { value, attr, .. } => {
let value_lifetime = self.handle_expr(value)?.unwrap();
if let Some(field_lifetime) = rhs_lifetime {
self.builder.append_ir(
FieldAssign { obj: value_lifetime, field: *attr, new: field_lifetime, is_init: false },
lhs.location,
);
}
}
ExprKind::Subscript { value, slice, .. } => {
let value_lifetime = self.handle_expr(value)?.unwrap();
let elem_lifetime = if let ExprKind::Slice { lower, upper, step } = &slice.node {
for expr in [lower, upper, step].iter().filter_map(|x| x.as_ref()) {
self.handle_expr(expr)?;
}
if let Some(rhs_lifetime) = rhs_lifetime {
// must be a list
Some(self.builder.append_ir(
FieldAccess { obj: rhs_lifetime, field: "$elem".into() },
lhs.location,
))
} else {
None
}
} else {
self.handle_expr(slice)?;
rhs_lifetime
};
// must be a list
if let Some(elem_lifetime) = elem_lifetime {
self.builder.append_ir(
FieldAssign {
obj: value_lifetime,
field: "$elem".into(),
new: elem_lifetime,
is_init: false
},
lhs.location,
);
}
}
ExprKind::Name { id, .. } => {
if let Some(lifetime) = rhs_lifetime {
self.builder.append_ir(VarAssign { var: *id, lifetime }, lhs.location);
}
}
ExprKind::Tuple { elts, .. } => {
let rhs_lifetime = rhs_lifetime.unwrap();
for (i, e) in elts.iter().enumerate() {
let elem_lifetime = self.builder.append_ir(
FieldAccess { obj: rhs_lifetime, field: format!("$elem{}", i).into() },
e.location,
);
self.handle_assignment(e, Some(elem_lifetime))?;
}
}
_ => unreachable!(),
}
Ok(())
}
fn handle_statement(&mut self, stmt: &Stmt<Option<Type>>) -> Result<(), String> {
use LifetimeIR::*;
match &stmt.node {
StmtKind::Expr { value, .. } => {
self.handle_expr(value)?;
}
StmtKind::Assign { targets, value, .. } => {
let rhs_lifetime = self.handle_expr(value)?;
for target in targets {
self.handle_assignment(target, rhs_lifetime)?;
}
}
StmtKind::If { test, body, orelse, .. } => {
// test should return bool
self.handle_expr(test)?;
let body_bb = self.builder.append_block();
let else_bb = self.builder.append_block();
self.builder.append_ir(Branch { targets: vec![body_bb, else_bb] }, stmt.location);
self.builder.position_at_end(body_bb);
self.handle_statements(&body)?;
let body_terminated = self.is_terminated();
if orelse.is_empty() {
if !body_terminated {
// else_bb is the basic block after this if statement
self.builder.append_ir(Branch { targets: vec![else_bb] }, stmt.location);
self.builder.position_at_end(else_bb);
}
} else {
let tail_bb = self.builder.append_block();
if !body_terminated {
self.builder.append_ir(Branch { targets: vec![tail_bb] }, stmt.location);
}
self.builder.position_at_end(else_bb);
self.handle_statements(&orelse)?;
if !self.is_terminated() {
self.builder.append_ir(Branch { targets: vec![tail_bb] }, stmt.location);
}
self.builder.position_at_end(tail_bb);
}
}
StmtKind::While { test, body, orelse, .. } => {
let old_loop_head = self.loop_head;
let old_loop_tail = self.loop_tail;
let loop_head = self.builder.append_block();
let loop_body = self.builder.append_block();
let loop_else =
if orelse.is_empty() { None } else { Some(self.builder.append_block()) };
let loop_tail = self.builder.append_block();
self.loop_head = Some(loop_head);
self.loop_tail = Some(loop_tail);
self.builder.append_ir(Branch { targets: vec![loop_head] }, stmt.location);
self.builder.position_at_end(loop_head);
self.handle_expr(test)?;
self.builder.append_ir(
Branch { targets: vec![loop_body, loop_else.unwrap_or(loop_tail)] },
stmt.location,
);
self.builder.position_at_end(loop_body);
self.handle_statements(&body)?;
if !self.is_terminated() {
self.builder.append_ir(Branch { targets: vec![loop_head] }, stmt.location);
}
self.loop_head = old_loop_head;
self.loop_tail = old_loop_tail;
if let Some(loop_else) = loop_else {
self.builder.position_at_end(loop_else);
self.handle_statements(&orelse)?;
if !self.is_terminated() {
self.builder.append_ir(Branch { targets: vec![loop_tail] }, stmt.location);
}
}
self.builder.position_at_end(loop_tail);
}
StmtKind::For { target, iter, body, orelse, .. } => {
let old_loop_head = self.loop_head;
let old_loop_tail = self.loop_tail;
let loop_head = self.builder.append_block();
let loop_body = self.builder.append_block();
let loop_else =
if orelse.is_empty() { None } else { Some(self.builder.append_block()) };
let loop_tail = self.builder.append_block();
self.loop_head = Some(loop_head);
self.loop_tail = Some(loop_tail);
let iter_lifetime = self.handle_expr(iter)?.map(|obj| {
self.builder
.append_ir(FieldAccess { obj, field: "$elem".into() }, iter.location)
});
self.builder.append_ir(Branch { targets: vec![loop_head] }, stmt.location);
self.builder.position_at_end(loop_head);
if let Some(iter_lifetime) = iter_lifetime {
self.handle_assignment(target, Some(iter_lifetime))?;
}
self.builder.append_ir(
Branch { targets: vec![loop_body, loop_else.unwrap_or(loop_tail)] },
stmt.location,
);
self.builder.position_at_end(loop_body);
self.handle_statements(&body)?;
if !self.is_terminated() {
self.builder.append_ir(Branch { targets: vec![loop_head] }, stmt.location);
}
self.loop_head = old_loop_head;
self.loop_tail = old_loop_tail;
if let Some(loop_else) = loop_else {
self.builder.position_at_end(loop_else);
self.handle_statements(&orelse)?;
if !self.is_terminated() {
self.builder.append_ir(Branch { targets: vec![loop_tail] }, stmt.location);
}
}
self.builder.position_at_end(loop_tail);
}
StmtKind::Continue { .. } => {
if let Some(loop_head) = self.loop_head {
self.builder.append_ir(Branch { targets: vec![loop_head] }, stmt.location);
} else {
return Err(format!("break outside loop"));
}
}
StmtKind::Break { .. } => {
if let Some(loop_tail) = self.loop_tail {
self.builder.append_ir(Branch { targets: vec![loop_tail] }, stmt.location);
} else {
return Err(format!("break outside loop"));
}
}
StmtKind::Return { value, .. } => {
let val = if let Some(value) = value { self.handle_expr(value)? } else { None };
self.builder.append_ir(Return { val }, stmt.location);
}
StmtKind::Pass { .. } => {}
_ => unimplemented!("{:?}", stmt.node),
}
Ok(())
}
fn handle_statements(&mut self, stmts: &[Stmt<Option<Type>>]) -> Result<(), String> {
for stmt in stmts.iter() {
if self.builder.is_terminated(self.builder.get_current_block()) {
break;
}
self.handle_statement(stmt)?;
}
Ok(())
}
}

View File

@ -1,60 +0,0 @@
use super::EscapeAnalyzer;
use crate::typecheck::{type_inferencer::test::TestEnvironment, typedef::TypeEnum};
use indoc::indoc;
use nac3parser::ast::fold::Fold;
use std::collections::hash_set::HashSet;
use test_case::test_case;
use nac3parser::parser::parse_program;
#[test_case(indoc! {"
# a: list[list[int32]]
b = [1]
a[0] = b
"}, Err("field lifetime error in unknown: line 3 column 2".into())
; "assign global elem")]
#[test_case(indoc! {"
# a: list[list[int32]]
b = [[], []]
b[1] = a
b[0][0] = [0]
"}, Err("field lifetime error in unknown: line 4 column 5".into())
; "global unify")]
#[test_case(indoc! {"
b = [1, 2, 3]
c = [a]
c[0][0] = b
"}, Err("field lifetime error in unknown: line 3 column 5".into())
; "global unify 2")]
fn test_simple(source: &str, expected_result: Result<(), String>) {
let mut env = TestEnvironment::basic_test_env();
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect();
defined_identifiers.insert("a".into());
let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone();
let list_int = inferencer.unifier.add_ty(TypeEnum::TList { ty: inferencer.primitives.int32 });
let list_list_int = inferencer.unifier.add_ty(TypeEnum::TList { ty: list_int });
inferencer.variable_mapping.insert("a".into(), list_list_int);
let statements = parse_program(source, Default::default()).unwrap();
let statements = statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
let mut lifetime_ctx = EscapeAnalyzer::new(
&mut inferencer.unifier,
&mut inferencer.primitives,
inferencer.function_data.resolver.clone(),
&inferencer.top_level,
);
lifetime_ctx.handle_statements(&statements).unwrap();
lifetime_ctx.builder.remove_empty_bb();
let result = lifetime_ctx.builder.analyze();
assert_eq!(result, expected_result);
}

View File

@ -1,14 +1,24 @@
use crate::typecheck::typedef::TypeEnum; use std::{
collections::{HashMap, HashSet},
iter::once,
};
use super::type_inferencer::Inferencer; use nac3parser::ast::{
use super::typedef::Type; self, Constant, Expr, ExprKind,
use nac3parser::ast::{self, Expr, ExprKind, Stmt, StmtKind, StrRef}; Operator::{LShift, RShift},
use std::{collections::HashSet, iter::once}; Stmt, StmtKind, StrRef,
};
use super::{
type_inferencer::{DeclarationSource, IdentifierInfo, Inferencer},
typedef::{Type, TypeEnum},
};
use crate::toplevel::helper::PrimDef;
impl<'a> Inferencer<'a> { impl<'a> Inferencer<'a> {
fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), String> { fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> {
if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) { if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) {
Err(format!("Error at {}: cannot have value none", expr.location)) Err(HashSet::from([format!("Error at {}: cannot have value none", expr.location)]))
} else { } else {
Ok(()) Ok(())
} }
@ -17,40 +27,61 @@ impl<'a> Inferencer<'a> {
fn check_pattern( fn check_pattern(
&mut self, &mut self,
pattern: &Expr<Option<Type>>, pattern: &Expr<Option<Type>>,
defined_identifiers: &mut HashSet<StrRef>, defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
) -> Result<(), String> { ) -> Result<(), HashSet<String>> {
match &pattern.node { match &pattern.node {
ast::ExprKind::Name { id, .. } if id == &"none".into() => ExprKind::Name { id, .. } if id == &"none".into() => {
Err(format!("cannot assign to a `none` (at {})", pattern.location)), Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)]))
}
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) { // If `id` refers to a declared symbol, reject this assignment if it is used in the
defined_identifiers.insert(*id); // context of an (implicit) global variable
if let Some(id_info) = defined_identifiers.get(id) {
if matches!(
id_info.source,
DeclarationSource::Global { is_explicit: Some(false) }
) {
return Err(HashSet::from([format!(
"cannot access local variable '{id}' before it is declared (at {})",
pattern.location
)]));
}
}
if !defined_identifiers.contains_key(id) {
defined_identifiers.insert(*id, IdentifierInfo::default());
} }
self.should_have_value(pattern)?; self.should_have_value(pattern)?;
Ok(()) Ok(())
} }
ExprKind::Tuple { elts, .. } => { ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
for elt in elts.iter() { for elt in elts {
self.check_pattern(elt, defined_identifiers)?; self.check_pattern(elt, defined_identifiers)?;
self.should_have_value(elt)?; self.should_have_value(elt)?;
} }
Ok(()) Ok(())
} }
ExprKind::Starred { value, .. } => {
self.check_pattern(value, defined_identifiers)?;
self.should_have_value(value)?;
Ok(())
}
ExprKind::Subscript { value, slice, .. } => { ExprKind::Subscript { value, slice, .. } => {
self.check_expr(value, defined_identifiers)?; self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?; self.should_have_value(value)?;
self.check_expr(slice, defined_identifiers)?; self.check_expr(slice, defined_identifiers)?;
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
return Err(format!( return Err(HashSet::from([format!(
"Error at {}: cannot assign to tuple element", "Error at {}: cannot assign to tuple element",
value.location value.location
)); )]));
} }
Ok(()) Ok(())
} }
ExprKind::Constant { .. } => { ExprKind::Constant { .. } => Err(HashSet::from([format!(
Err(format!("cannot assign to a constant (at {})", pattern.location)) "cannot assign to a constant (at {})",
} pattern.location
)])),
_ => self.check_expr(pattern, defined_identifiers), _ => self.check_expr(pattern, defined_identifiers),
} }
} }
@ -58,16 +89,19 @@ impl<'a> Inferencer<'a> {
fn check_expr( fn check_expr(
&mut self, &mut self,
expr: &Expr<Option<Type>>, expr: &Expr<Option<Type>>,
defined_identifiers: &mut HashSet<StrRef>, defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
) -> Result<(), String> { ) -> Result<(), HashSet<String>> {
// there are some cases where the custom field is None // there are some cases where the custom field is None
if let Some(ty) = &expr.custom { if let Some(ty) = &expr.custom {
if !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) { if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. })
return Err(format!( && !ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::List.id())
&& !self.unifier.is_concrete(*ty, &self.function_data.bound_variables)
{
return Err(HashSet::from([format!(
"expected concrete type at {} but got {}", "expected concrete type at {} but got {}",
expr.location, expr.location,
self.unifier.get_ty(*ty).get_type_name() self.unifier.get_ty(*ty).get_type_name()
)); )]));
} }
} }
match &expr.node { match &expr.node {
@ -76,7 +110,7 @@ impl<'a> Inferencer<'a> {
return Ok(()); return Ok(());
} }
self.should_have_value(expr)?; self.should_have_value(expr)?;
if !defined_identifiers.contains(id) { if !defined_identifiers.contains_key(id) {
match self.function_data.resolver.get_symbol_type( match self.function_data.resolver.get_symbol_type(
self.unifier, self.unifier,
&self.top_level.definitions.read(), &self.top_level.definitions.read(),
@ -84,13 +118,28 @@ impl<'a> Inferencer<'a> {
*id, *id,
) { ) {
Ok(_) => { Ok(_) => {
self.defined_identifiers.insert(*id); let is_global = self.is_id_global(*id);
defined_identifiers.insert(
*id,
IdentifierInfo {
source: match is_global {
Some(true) => {
DeclarationSource::Global { is_explicit: Some(false) }
}
Some(false) => {
DeclarationSource::Global { is_explicit: None }
}
None => DeclarationSource::Local,
},
},
);
} }
Err(e) => { Err(e) => {
return Err(format!( return Err(HashSet::from([format!(
"type error at identifier `{}` ({}) at {}", "type error at identifier `{}` ({}) at {}",
id, e, expr.location id, e, expr.location
)); )]))
} }
} }
} }
@ -98,7 +147,7 @@ impl<'a> Inferencer<'a> {
ExprKind::List { elts, .. } ExprKind::List { elts, .. }
| ExprKind::Tuple { elts, .. } | ExprKind::Tuple { elts, .. }
| ExprKind::BoolOp { values: elts, .. } => { | ExprKind::BoolOp { values: elts, .. } => {
for elt in elts.iter() { for elt in elts {
self.check_expr(elt, defined_identifiers)?; self.check_expr(elt, defined_identifiers)?;
self.should_have_value(elt)?; self.should_have_value(elt)?;
} }
@ -107,11 +156,25 @@ impl<'a> Inferencer<'a> {
self.check_expr(value, defined_identifiers)?; self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?; self.should_have_value(value)?;
} }
ExprKind::BinOp { left, right, .. } => { ExprKind::BinOp { left, op, right } => {
self.check_expr(left, defined_identifiers)?; self.check_expr(left, defined_identifiers)?;
self.check_expr(right, defined_identifiers)?; self.check_expr(right, defined_identifiers)?;
self.should_have_value(left)?; self.should_have_value(left)?;
self.should_have_value(right)?; self.should_have_value(right)?;
// Check whether a bitwise shift has a negative RHS constant value
if *op == LShift || *op == RShift {
if let ExprKind::Constant { value, .. } = &right.node {
let Constant::Int(rhs_val) = value else { unreachable!() };
if *rhs_val < 0 {
return Err(HashSet::from([format!(
"shift count is negative at {}",
right.location
)]));
}
}
}
} }
ExprKind::UnaryOp { operand, .. } => { ExprKind::UnaryOp { operand, .. } => {
self.check_expr(operand, defined_identifiers)?; self.check_expr(operand, defined_identifiers)?;
@ -141,11 +204,9 @@ impl<'a> Inferencer<'a> {
} }
ExprKind::Lambda { args, body } => { ExprKind::Lambda { args, body } => {
let mut defined_identifiers = defined_identifiers.clone(); let mut defined_identifiers = defined_identifiers.clone();
for arg in args.args.iter() { for arg in &args.args {
// TODO: should we check the types here? // TODO: should we check the types here?
if !defined_identifiers.contains(&arg.node.arg) { defined_identifiers.entry(arg.node.arg).or_default();
defined_identifiers.insert(arg.node.arg);
}
} }
self.check_expr(body, &mut defined_identifiers)?; self.check_expr(body, &mut defined_identifiers)?;
} }
@ -179,24 +240,49 @@ impl<'a> Inferencer<'a> {
Ok(()) Ok(())
} }
/// Check that the return value is a non-`alloca` type, effectively only allowing primitive types.
///
/// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
/// is freed when the function returns.
fn check_return_value_ty(&mut self, ret_ty: Type) -> bool {
if cfg!(feature = "no-escape-analysis") {
true
} else {
match &*self.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { .. } => [
self.primitives.int32,
self.primitives.int64,
self.primitives.uint32,
self.primitives.uint64,
self.primitives.float,
self.primitives.bool,
]
.iter()
.any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
TypeEnum::TTuple { ty, .. } => ty.iter().all(|t| self.check_return_value_ty(*t)),
_ => false,
}
}
}
// check statements for proper identifier def-use and return on all paths // check statements for proper identifier def-use and return on all paths
fn check_stmt( fn check_stmt(
&mut self, &mut self,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
defined_identifiers: &mut HashSet<StrRef>, defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
) -> Result<bool, String> { ) -> Result<bool, HashSet<String>> {
match &stmt.node { match &stmt.node {
StmtKind::For { target, iter, body, orelse, .. } => { StmtKind::For { target, iter, body, orelse, .. } => {
self.check_expr(iter, defined_identifiers)?; self.check_expr(iter, defined_identifiers)?;
self.should_have_value(iter)?; self.should_have_value(iter)?;
let mut local_defined_identifiers = defined_identifiers.clone(); let mut local_defined_identifiers = defined_identifiers.clone();
for stmt in orelse.iter() { for stmt in orelse {
self.check_stmt(stmt, &mut local_defined_identifiers)?; self.check_stmt(stmt, &mut local_defined_identifiers)?;
} }
let mut local_defined_identifiers = defined_identifiers.clone(); let mut local_defined_identifiers = defined_identifiers.clone();
self.check_pattern(target, &mut local_defined_identifiers)?; self.check_pattern(target, &mut local_defined_identifiers)?;
self.should_have_value(target)?; self.should_have_value(target)?;
for stmt in body.iter() { for stmt in body {
self.check_stmt(stmt, &mut local_defined_identifiers)?; self.check_stmt(stmt, &mut local_defined_identifiers)?;
} }
Ok(false) Ok(false)
@ -209,9 +295,11 @@ impl<'a> Inferencer<'a> {
let body_returned = self.check_block(body, &mut body_identifiers)?; let body_returned = self.check_block(body, &mut body_identifiers)?;
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?; let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
for ident in body_identifiers.iter() { for ident in body_identifiers.keys() {
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) { if !defined_identifiers.contains_key(ident)
defined_identifiers.insert(*ident); && orelse_identifiers.contains_key(ident)
{
defined_identifiers.insert(*ident, IdentifierInfo::default());
} }
} }
Ok(body_returned && orelse_returned) Ok(body_returned && orelse_returned)
@ -226,7 +314,7 @@ impl<'a> Inferencer<'a> {
} }
StmtKind::With { items, body, .. } => { StmtKind::With { items, body, .. } => {
let mut new_defined_identifiers = defined_identifiers.clone(); let mut new_defined_identifiers = defined_identifiers.clone();
for item in items.iter() { for item in items {
self.check_expr(&item.context_expr, defined_identifiers)?; self.check_expr(&item.context_expr, defined_identifiers)?;
if let Some(var) = item.optional_vars.as_ref() { if let Some(var) = item.optional_vars.as_ref() {
self.check_pattern(var, &mut new_defined_identifiers)?; self.check_pattern(var, &mut new_defined_identifiers)?;
@ -238,11 +326,11 @@ impl<'a> Inferencer<'a> {
StmtKind::Try { body, handlers, orelse, finalbody, .. } => { StmtKind::Try { body, handlers, orelse, finalbody, .. } => {
self.check_block(body, &mut defined_identifiers.clone())?; self.check_block(body, &mut defined_identifiers.clone())?;
self.check_block(orelse, &mut defined_identifiers.clone())?; self.check_block(orelse, &mut defined_identifiers.clone())?;
for handler in handlers.iter() { for handler in handlers {
let mut defined_identifiers = defined_identifiers.clone(); let mut defined_identifiers = defined_identifiers.clone();
let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node; let ast::ExcepthandlerKind::ExceptHandler { name, body, .. } = &handler.node;
if let Some(name) = name { if let Some(name) = name {
defined_identifiers.insert(*name); defined_identifiers.insert(*name, IdentifierInfo::default());
} }
self.check_block(body, &mut defined_identifiers)?; self.check_block(body, &mut defined_identifiers)?;
} }
@ -273,6 +361,30 @@ impl<'a> Inferencer<'a> {
if let Some(value) = value { if let Some(value) = value {
self.check_expr(value, defined_identifiers)?; self.check_expr(value, defined_identifiers)?;
self.should_have_value(value)?; self.should_have_value(value)?;
// Check that the return value is a non-`alloca` type, effectively only allowing primitive types.
// This is a workaround preventing the caller from using a variable `alloca`-ed in the body, which
// is freed when the function returns.
if let Some(ret_ty) = value.custom {
// Explicitly allow ellipsis as a return value, as the type of the ellipsis is contextually
// inferred and just generates an unconditional assertion
if matches!(
value.node,
ExprKind::Constant { value: Constant::Ellipsis, .. }
) {
return Ok(true);
}
if !self.check_return_value_ty(ret_ty) {
return Err(HashSet::from([
format!(
"return value of type {} must be a primitive or a tuple of primitives at {}",
self.unifier.stringify(ret_ty),
value.location,
),
]));
}
}
} }
Ok(true) Ok(true)
} }
@ -282,6 +394,44 @@ impl<'a> Inferencer<'a> {
} }
Ok(true) Ok(true)
} }
StmtKind::Global { names, .. } => {
for id in names {
if let Some(id_info) = defined_identifiers.get(id) {
if id_info.source == DeclarationSource::Local {
return Err(HashSet::from([format!(
"name '{id}' is referenced prior to global declaration at {}",
stmt.location,
)]));
}
continue;
}
match self.function_data.resolver.get_symbol_type(
self.unifier,
&self.top_level.definitions.read(),
self.primitives,
*id,
) {
Ok(_) => {
defined_identifiers.insert(
*id,
IdentifierInfo {
source: DeclarationSource::Global { is_explicit: Some(true) },
},
);
}
Err(e) => {
return Err(HashSet::from([format!(
"type error at identifier `{}` ({}) at {}",
id, e, stmt.location
)]))
}
}
}
Ok(false)
}
// break, raise, etc. // break, raise, etc.
_ => Ok(false), _ => Ok(false),
} }
@ -290,12 +440,12 @@ impl<'a> Inferencer<'a> {
pub fn check_block( pub fn check_block(
&mut self, &mut self,
block: &[Stmt<Option<Type>>], block: &[Stmt<Option<Type>>],
defined_identifiers: &mut HashSet<StrRef>, defined_identifiers: &mut HashMap<StrRef, IdentifierInfo>,
) -> Result<bool, String> { ) -> Result<bool, HashSet<String>> {
let mut ret = false; let mut ret = false;
for stmt in block { for stmt in block {
if ret { if ret {
return Err(format!("dead code at {:?}", stmt.location)); eprintln!("warning: dead code at {}\n", stmt.location);
} }
if self.check_stmt(stmt, defined_identifiers)? { if self.check_stmt(stmt, defined_identifiers)? {
ret = true; ret = true;

View File

@ -1,69 +1,154 @@
use crate::typecheck::{ use std::{cmp::max, collections::HashMap, rc::Rc};
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, use itertools::{iproduct, Itertools};
use strum::IntoEnumIterator;
use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop};
use crate::{
symbol_resolver::SymbolValue,
toplevel::{
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
},
typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
},
}; };
use nac3parser::ast::{self, StrRef};
use nac3parser::ast::{Cmpop, Operator, Unaryop};
use std::collections::HashMap;
use std::rc::Rc;
pub fn binop_name(op: &Operator) -> &'static str { /// The variant of a binary operator.
match op { #[derive(Debug, Clone, Copy, PartialEq, Eq)]
Operator::Add => "__add__", pub enum BinopVariant {
Operator::Sub => "__sub__", /// The normal variant.
Operator::Div => "__truediv__", /// For addition, it would be `+`.
Operator::Mod => "__mod__", Normal,
Operator::Mult => "__mul__", /// The "Augmented Assigning Operator" variant.
Operator::Pow => "__pow__", /// For addition, it would be `+=`.
Operator::BitOr => "__or__", AugAssign,
Operator::BitXor => "__xor__", }
Operator::BitAnd => "__and__",
Operator::LShift => "__lshift__", /// A binary operator with its variant.
Operator::RShift => "__rshift__", #[derive(Debug, Clone, Copy)]
Operator::FloorDiv => "__floordiv__", pub struct Binop {
Operator::MatMult => "__matmul__", /// The base [`Operator`] of this binary operator.
pub base: Operator,
/// The variant of this binary operator.
pub variant: BinopVariant,
}
impl Binop {
/// Make a [`Binop`] of the normal variant from an [`Operator`].
#[must_use]
pub fn normal(base: Operator) -> Self {
Binop { base, variant: BinopVariant::Normal }
}
/// Make a [`Binop`] of the aug assign variant from an [`Operator`].
#[must_use]
pub fn aug_assign(base: Operator) -> Self {
Binop { base, variant: BinopVariant::AugAssign }
} }
} }
pub fn binop_assign_name(op: &Operator) -> &'static str { /// Details about an operator (unary, binary, etc...) in Python
match op { #[derive(Debug, Clone, Copy)]
Operator::Add => "__iadd__", pub struct OpInfo {
Operator::Sub => "__isub__", /// The method name of the binary operator.
Operator::Div => "__itruediv__", /// For addition, this would be `__add__`, and `__iadd__` if
Operator::Mod => "__imod__", /// it is the augmented assigning variant.
Operator::Mult => "__imul__", pub method_name: &'static str,
Operator::Pow => "__ipow__", /// The symbol of the binary operator.
Operator::BitOr => "__ior__", /// For addition, this would be `+`, and `+=` if
Operator::BitXor => "__ixor__", /// it is the augmented assigning variant.
Operator::BitAnd => "__iand__", pub symbol: &'static str,
Operator::LShift => "__ilshift__",
Operator::RShift => "__irshift__",
Operator::FloorDiv => "__ifloordiv__",
Operator::MatMult => "__imatmul__",
}
} }
pub fn unaryop_name(op: &Unaryop) -> &'static str { /// Helper macro to conveniently build an [`OpInfo`].
match op { ///
Unaryop::UAdd => "__pos__", /// Example usage: `make_info("add", "+")` generates `OpInfo { name: "__add__", symbol: "+" }`
Unaryop::USub => "__neg__", macro_rules! make_op_info {
Unaryop::Not => "__not__", ($name:expr, $symbol:expr) => {
Unaryop::Invert => "__inv__", OpInfo { method_name: concat!("__", $name, "__"), symbol: $symbol }
} };
} }
pub fn comparison_name(op: &Cmpop) -> Option<&'static str> { pub trait HasOpInfo {
fn op_info(&self) -> OpInfo;
}
fn try_get_cmpop_info(op: Cmpop) -> Option<OpInfo> {
match op { match op {
Cmpop::Lt => Some("__lt__"), Cmpop::Lt => Some(make_op_info!("lt", "<")),
Cmpop::LtE => Some("__le__"), Cmpop::LtE => Some(make_op_info!("le", "<=")),
Cmpop::Gt => Some("__gt__"), Cmpop::Gt => Some(make_op_info!("gt", ">")),
Cmpop::GtE => Some("__ge__"), Cmpop::GtE => Some(make_op_info!("ge", ">=")),
Cmpop::Eq => Some("__eq__"), Cmpop::Eq => Some(make_op_info!("eq", "==")),
Cmpop::NotEq => Some("__ne__"), Cmpop::NotEq => Some(make_op_info!("ne", "!=")),
_ => None, _ => None,
} }
} }
impl OpInfo {
#[must_use]
pub fn supports_cmpop(op: Cmpop) -> bool {
try_get_cmpop_info(op).is_some()
}
}
impl HasOpInfo for Cmpop {
fn op_info(&self) -> OpInfo {
try_get_cmpop_info(*self).expect("{self:?} is not supported")
}
}
impl HasOpInfo for Binop {
fn op_info(&self) -> OpInfo {
// Helper macro to generate both the normal variant [`OpInfo`] and the
// augmented assigning variant [`OpInfo`] for a binary operator conveniently.
macro_rules! info {
($name:literal, $symbol:literal) => {
(
make_op_info!($name, $symbol),
make_op_info!(concat!("i", $name), concat!($symbol, "=")),
)
};
}
let (normal_variant, aug_assign_variant) = match self.base {
Operator::Add => info!("add", "+"),
Operator::Sub => info!("sub", "-"),
Operator::Div => info!("truediv", "/"),
Operator::Mod => info!("mod", "%"),
Operator::Mult => info!("mul", "*"),
Operator::Pow => info!("pow", "**"),
Operator::BitOr => info!("or", "|"),
Operator::BitXor => info!("xor", "^"),
Operator::BitAnd => info!("and", "&"),
Operator::LShift => info!("lshift", "<<"),
Operator::RShift => info!("rshift", ">>"),
Operator::FloorDiv => info!("floordiv", "//"),
Operator::MatMult => info!("matmul", "@"),
};
match self.variant {
BinopVariant::Normal => normal_variant,
BinopVariant::AugAssign => aug_assign_variant,
}
}
}
impl HasOpInfo for Unaryop {
fn op_info(&self) -> OpInfo {
match self {
Unaryop::UAdd => make_op_info!("pos", "+"),
Unaryop::USub => make_op_info!("neg", "-"),
Unaryop::Not => make_op_info!("not", "not"), // i.e., `not False`, so the symbol is just `not`.
Unaryop::Invert => make_op_info!("inv", "~"),
}
}
}
pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F) pub(super) fn with_fields<F>(unifier: &mut Unifier, ty: Type, f: F)
where where
F: FnOnce(&mut Unifier, &mut HashMap<StrRef, (Type, bool)>), F: FnOnce(&mut Unifier, &mut HashMap<StrRef, (Type, bool)>),
@ -83,26 +168,31 @@ where
pub fn impl_binop( pub fn impl_binop(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, _store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
ops: &[ast::Operator], ops: &[Operator],
) { ) {
with_fields(unifier, ty, |unifier, fields| { with_fields(unifier, ty, |unifier, fields| {
let (other_ty, other_var_id) = if other_ty.len() == 1 { let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None) (other_ty[0], None)
} else { } else {
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(ty, Some(var_id)) (tvar.ty, Some(tvar.id))
}; };
let function_vars = if let Some(var_id) = other_var_id { let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<HashMap<_, _>>() vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
} else { } else {
HashMap::new() VarMap::new()
}; };
for op in ops {
fields.insert(binop_name(op).into(), { let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) {
let op = Binop { base: *base_op, variant };
fields.insert(op.op_info().method_name.into(), {
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
@ -111,21 +201,7 @@ pub fn impl_binop(
ty: other_ty, ty: other_ty,
default_value: None, default_value: None,
name: "other".into(), name: "other".into(),
}], is_vararg: false,
})),
false,
)
});
fields.insert(binop_assign_name(op).into(), {
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.none,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
name: "other".into(),
}], }],
})), })),
false, false,
@ -135,15 +211,17 @@ pub fn impl_binop(
}); });
} }
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[ast::Unaryop]) { pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops: &[Unaryop]) {
with_fields(unifier, ty, |unifier, fields| { with_fields(unifier, ty, |unifier, fields| {
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for op in ops { for op in ops {
fields.insert( fields.insert(
unaryop_name(op).into(), op.op_info().method_name.into(),
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
vars: HashMap::new(), vars: VarMap::new(),
args: vec![], args: vec![],
})), })),
false, false,
@ -155,23 +233,40 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[ast::U
pub fn impl_cmpop( pub fn impl_cmpop(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, _store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: Type, other_ty: &[Type],
ops: &[ast::Cmpop], ops: &[Cmpop],
ret_ty: Option<Type>,
) { ) {
with_fields(unifier, ty, |unifier, fields| { with_fields(unifier, ty, |unifier, fields| {
let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None)
} else {
let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(tvar.ty, Some(tvar.id))
};
let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
} else {
VarMap::new()
};
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty);
for op in ops { for op in ops {
fields.insert( fields.insert(
comparison_name(op).unwrap().into(), op.op_info().method_name.into(),
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.bool, ret: ret_ty,
vars: HashMap::new(), vars: function_vars.clone(),
args: vec![FuncArg { args: vec![FuncArg {
ty: other_ty, ty: other_ty,
default_value: None, default_value: None,
name: "other".into(), name: "other".into(),
is_vararg: false,
}], }],
})), })),
false, false,
@ -181,13 +276,13 @@ pub fn impl_cmpop(
}); });
} }
/// Add, Sub, Mult /// `Add`, `Sub`, `Mult`
pub fn impl_basic_arithmetic( pub fn impl_basic_arithmetic(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
) { ) {
impl_binop( impl_binop(
unifier, unifier,
@ -195,94 +290,407 @@ pub fn impl_basic_arithmetic(
ty, ty,
other_ty, other_ty,
ret_ty, ret_ty,
&[ast::Operator::Add, ast::Operator::Sub, ast::Operator::Mult], &[Operator::Add, Operator::Sub, Operator::Mult],
) );
} }
/// Pow /// `Pow`
pub fn impl_pow( pub fn impl_pow(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Pow]) impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]);
} }
/// BitOr, BitXor, BitAnd /// `BitOr`, `BitXor`, `BitAnd`
pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop( impl_binop(
unifier, unifier,
store, store,
ty, ty,
&[ty], &[ty],
ty, Some(ty),
&[ast::Operator::BitAnd, ast::Operator::BitOr, ast::Operator::BitXor], &[Operator::BitAnd, Operator::BitOr, Operator::BitXor],
) );
} }
/// LShift, RShift /// `LShift`, `RShift`
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, &[ty], ty, &[ast::Operator::LShift, ast::Operator::RShift]) impl_binop(
unifier,
store,
ty,
&[store.int32, store.uint32],
Some(ty),
&[Operator::LShift, Operator::RShift],
);
} }
/// Div /// `Div`
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) { pub fn impl_div(
impl_binop(unifier, store, ty, other_ty, store.float, &[ast::Operator::Div]) unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Div]);
} }
/// FloorDiv /// `FloorDiv`
pub fn impl_floordiv( pub fn impl_floordiv(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::FloorDiv]) impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]);
} }
/// Mod /// `Mod`
pub fn impl_mod( pub fn impl_mod(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Option<Type>,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[ast::Operator::Mod]) impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]);
} }
/// UAdd, USub /// [`Operator::MatMult`]
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { pub fn impl_matmul(
impl_unaryop(unifier, ty, ty, &[ast::Unaryop::UAdd, ast::Unaryop::USub]) unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::MatMult]);
} }
/// Invert /// `UAdd`, `USub`
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) { pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ty, &[ast::Unaryop::Invert]) impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]);
} }
/// Not /// `Invert`
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, store.bool, &[ast::Unaryop::Not]) impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Invert]);
} }
/// Lt, LtE, Gt, GtE /// `Not`
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) { pub fn impl_not(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Not]);
}
/// `Lt`, `LtE`, `Gt`, `GtE`
pub fn impl_comparison(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_cmpop( impl_cmpop(
unifier, unifier,
store, store,
ty, ty,
other_ty, other_ty,
&[ast::Cmpop::Lt, ast::Cmpop::Gt, ast::Cmpop::LtE, ast::Cmpop::GtE], &[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE],
) ret_ty,
);
} }
/// Eq, NotEq /// `Eq`, `NotEq`
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_eq(
impl_cmpop(unifier, store, ty, ty, &[ast::Cmpop::Eq, ast::Cmpop::NotEq]) unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
}
/// Returns the expected return type of binary operations with at least one `ndarray` operand.
pub fn typeof_ndarray_broadcast(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
left: Type,
right: Type,
) -> Result<Type, String> {
let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
assert!(is_left_ndarray || is_right_ndarray);
if is_left_ndarray && is_right_ndarray {
// Perform broadcasting on two ndarray operands.
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(),
};
let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(),
};
let res_ndims = left_ty_ndims
.into_iter()
.cartesian_product(right_ty_ndims)
.map(|(left, right)| {
let left_val = u64::try_from(left).unwrap();
let right_val = u64::try_from(right).unwrap();
max(left_val, right_val)
})
.unique()
.map(SymbolValue::U64)
.collect_vec();
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
} else {
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
Ok(ndarray_ty)
} else {
let (expected_ty, actual_ty) = if is_left_ndarray {
(ndarray_ty_dtype, scalar_ty)
} else {
(scalar_ty, ndarray_ty_dtype)
};
Err(format!(
"Expected right-hand side operand to be {}, got {}",
unifier.stringify(expected_ty),
unifier.stringify(actual_ty),
))
}
}
}
/// Returns the return type given a binary operator and its primitive operands.
pub fn typeof_binop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
op: Operator,
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let op = Binop { base: op, variant: BinopVariant::Normal };
let is_left_list = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
let is_right_list = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
Ok(Some(match op.base {
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
if is_left_list || is_right_list {
if ![Operator::Add, Operator::Mult].contains(&op.base) {
return Err(format!(
"Binary operator {} not supported for list",
op.op_info().symbol
));
}
if is_left_list {
lhs
} else {
rhs
}
} else if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) {
lhs
} else {
return Ok(None);
}
}
Operator::MatMult => {
// NOTE: NumPy matmul's LHS and RHS must both be ndarrays. Scalars are not allowed.
match (&*unifier.get_ty(lhs), &*unifier.get_ty(rhs)) {
(
TypeEnum::TObj { obj_id: lhs_obj_id, .. },
TypeEnum::TObj { obj_id: rhs_obj_id, .. },
) if *lhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap()
&& *rhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap() =>
{
// LHS and RHS have valid types
}
_ => {
let lhs_str = unifier.stringify(lhs);
let rhs_str = unifier.stringify(rhs);
return Err(format!("ndarray.__matmul__ only accepts ndarray operands, but left operand has type {lhs_str}, and right operand has type {rhs_str}"));
}
}
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
}
_ => unreachable!(),
};
let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs);
let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) {
TypeEnum::TLiteral { values, .. } => {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
}
_ => unreachable!(),
};
match (lhs_ndims, rhs_ndims) {
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
(lhs, rhs) if lhs == 0 || rhs == 0 => {
return Err(format!(
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
u8::from(rhs == 0)
))
}
(lhs, rhs) => {
return Err(format!(
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
))
}
}
}
Operator::Div => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) {
primitives.float
} else {
return Ok(None);
}
}
Operator::Pow => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if [
primitives.int32,
primitives.int64,
primitives.uint32,
primitives.uint64,
primitives.float,
]
.into_iter()
.any(|ty| unifier.unioned(lhs, ty))
{
lhs
} else {
return Ok(None);
}
}
Operator::LShift | Operator::RShift => lhs,
Operator::BitOr | Operator::BitXor | Operator::BitAnd => {
if unifier.unioned(lhs, rhs) {
lhs
} else {
return Ok(None);
}
}
}))
}
pub fn typeof_unaryop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
op: Unaryop,
operand: Type,
) -> Result<Option<Type>, String> {
let operand_obj_id = operand.obj_id(unifier);
if op == Unaryop::Not
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap())
{
return Err(
"The truth value of an array with more than one element is ambiguous".to_string()
);
}
Ok(match op {
Unaryop::Not => match operand_obj_id {
Some(v) if v == PrimDef::NDArray.id() => Some(operand),
Some(_) => Some(primitives.bool),
_ => None,
},
Unaryop::Invert => {
if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) {
Some(operand)
} else {
None
}
}
Unaryop::UAdd | Unaryop::USub => {
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
return Err(if op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
} else {
"The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string()
});
}
Some(operand)
} else if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) {
Some(operand)
} else {
None
}
}
})
}
/// Returns the return type given a comparison operator and its primitive operands.
pub fn typeof_cmpop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
_op: Cmpop,
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
} else if unifier.unioned(lhs, rhs) {
primitives.bool
} else {
return Ok(None);
}))
} }
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
@ -293,38 +701,81 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
bool: bool_t, bool: bool_t,
uint32: uint32_t, uint32: uint32_t,
uint64: uint64_t, uint64: uint64_t,
str: str_t,
list: list_t,
ndarray: ndarray_t,
.. ..
} = *store; } = *store;
let size_t = store.usize();
/* int ======== */ /* int ======== */
for t in [int32_t, int64_t, uint32_t, uint64_t] { for t in [int32_t, int64_t, uint32_t, uint64_t] {
impl_basic_arithmetic(unifier, store, t, &[t], t); let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None);
impl_pow(unifier, store, t, &[t], t); impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None);
impl_pow(unifier, store, t, &[t, ndarray_int_t], None);
impl_bitwise_arithmetic(unifier, store, t); impl_bitwise_arithmetic(unifier, store, t);
impl_bitwise_shift(unifier, store, t); impl_bitwise_shift(unifier, store, t);
impl_div(unifier, store, t, &[t]); impl_div(unifier, store, t, &[t, ndarray_int_t], None);
impl_floordiv(unifier, store, t, &[t], t); impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
impl_mod(unifier, store, t, &[t], t); impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
impl_invert(unifier, store, t); impl_invert(unifier, store, t, Some(t));
impl_not(unifier, store, t); impl_not(unifier, store, t, Some(bool_t));
impl_comparison(unifier, store, t, t); impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
impl_eq(unifier, store, t); impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
} }
for t in [int32_t, int64_t] { for t in [int32_t, int64_t] {
impl_sign(unifier, store, t); impl_sign(unifier, store, t, Some(t));
} }
/* float ======== */ /* float ======== */
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t); let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t); let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None);
impl_div(unifier, store, float_t, &[float_t]); impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_floordiv(unifier, store, float_t, &[float_t], float_t); impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None);
impl_mod(unifier, store, float_t, &[float_t], float_t); impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_sign(unifier, store, float_t); impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_not(unifier, store, float_t); impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_comparison(unifier, store, float_t, float_t); impl_sign(unifier, store, float_t, Some(float_t));
impl_eq(unifier, store, float_t); impl_not(unifier, store, float_t, Some(bool_t));
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
/* bool ======== */ /* bool ======== */
impl_not(unifier, store, bool_t); let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
impl_eq(unifier, store, bool_t); impl_invert(unifier, store, bool_t, Some(int32_t));
impl_not(unifier, store, bool_t, Some(bool_t));
impl_sign(unifier, store, bool_t, Some(int32_t));
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* str ========= */
impl_cmpop(unifier, store, str_t, &[str_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
/* list ======== */
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);
impl_cmpop(unifier, store, list_t, &[list_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
/* ndarray ===== */
let ndarray_usized_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t =
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.ty));
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
impl_basic_arithmetic(
unifier,
store,
ndarray_t,
&[ndarray_unsized_t, ndarray_unsized_dtype_t],
None,
);
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t));
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
} }

View File

@ -4,4 +4,3 @@ pub mod type_error;
pub mod type_inferencer; pub mod type_inferencer;
pub mod typedef; pub mod typedef;
mod unification_table; mod unification_table;
pub mod escape_analysis;

View File

@ -1,24 +1,46 @@
use std::collections::HashMap; use std::{collections::HashMap, fmt::Display};
use std::fmt::Display;
use crate::typecheck::typedef::TypeEnum; use itertools::Itertools;
use super::typedef::{RecordKey, Type, Unifier}; use nac3parser::ast::{Cmpop, Location, StrRef};
use nac3parser::ast::{Location, StrRef};
use super::{
magic_methods::Binop,
typedef::{RecordKey, Type, Unifier},
};
use crate::typecheck::{magic_methods::HasOpInfo, typedef::TypeEnum};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum TypeErrorKind { pub enum TypeErrorKind {
TooManyArguments { GotMultipleValues {
expected: usize, name: StrRef,
got: usize, },
TooManyArguments {
expected_min_count: usize,
expected_max_count: usize,
got_count: usize,
},
MissingArgs {
missing_arg_names: Vec<StrRef>,
}, },
MissingArgs(String),
UnknownArgName(StrRef), UnknownArgName(StrRef),
IncorrectArgType { IncorrectArgType {
name: StrRef, name: StrRef,
expected: Type, expected: Type,
got: Type, got: Type,
}, },
UnsupportedBinaryOpTypes {
operator: Binop,
lhs_type: Type,
rhs_type: Type,
expected_rhs_type: Type,
},
UnsupportedComparsionOpTypes {
operator: Cmpop,
lhs_type: Type,
rhs_type: Type,
expected_rhs_type: Type,
},
FieldUnificationError { FieldUnificationError {
field: RecordKey, field: RecordKey,
types: (Type, Type), types: (Type, Type),
@ -34,6 +56,7 @@ pub enum TypeErrorKind {
}, },
RequiresTypeAnn, RequiresTypeAnn,
PolymorphicFunctionPointer, PolymorphicFunctionPointer,
NoSuchAttribute(RecordKey, Type),
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -43,15 +66,18 @@ pub struct TypeError {
} }
impl TypeError { impl TypeError {
#[must_use]
pub fn new(kind: TypeErrorKind, loc: Option<Location>) -> TypeError { pub fn new(kind: TypeErrorKind, loc: Option<Location>) -> TypeError {
TypeError { kind, loc } TypeError { kind, loc }
} }
#[must_use]
pub fn at(mut self, loc: Option<Location>) -> TypeError { pub fn at(mut self, loc: Option<Location>) -> TypeError {
self.loc = self.loc.or(loc); self.loc = self.loc.or(loc);
self self
} }
#[must_use]
pub fn to_display(self, unifier: &Unifier) -> DisplayTypeError { pub fn to_display(self, unifier: &Unifier) -> DisplayTypeError {
DisplayTypeError { err: self, unifier } DisplayTypeError { err: self, unifier }
} }
@ -64,8 +90,8 @@ pub struct DisplayTypeError<'a> {
fn loc_to_str(loc: Option<Location>) -> String { fn loc_to_str(loc: Option<Location>) -> String {
match loc { match loc {
Some(loc) => format!("(in {})", loc), Some(loc) => format!("(in {loc})"),
None => "".to_string(), None => String::new(),
} }
} }
@ -74,23 +100,49 @@ impl<'a> Display for DisplayTypeError<'a> {
use TypeErrorKind::*; use TypeErrorKind::*;
let mut notes = Some(HashMap::new()); let mut notes = Some(HashMap::new());
match &self.err.kind { match &self.err.kind {
TooManyArguments { expected, got } => { GotMultipleValues { name } => {
write!(f, "Too many arguments. Expected {} but got {}", expected, got) write!(f, "For multiple values for parameter {name}")
} }
MissingArgs(args) => { TooManyArguments { expected_min_count, expected_max_count, got_count } => {
write!(f, "Missing arguments: {}", args) debug_assert!(expected_min_count <= expected_max_count);
if expected_min_count == expected_max_count {
let expected_count = expected_min_count; // or expected_max_count
write!(f, "Too many arguments. Expected {expected_count} but got {got_count}")
} else {
write!(f, "Too many arguments. Expected {expected_min_count} to {expected_max_count} arguments but got {got_count}")
}
}
MissingArgs { missing_arg_names } => {
let args = missing_arg_names.iter().join(", ");
write!(f, "Missing arguments: {args}")
}
UnsupportedBinaryOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
let op_symbol = operator.op_info().symbol;
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
let expected_rhs_type_str =
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
write!(f, "Unsupported operand type(s) for {op_symbol}: '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
}
UnsupportedComparsionOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
let op_symbol = operator.op_info().symbol;
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
let expected_rhs_type_str =
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
write!(f, "'{op_symbol}' not supported between instances of '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
} }
UnknownArgName(name) => { UnknownArgName(name) => {
write!(f, "Unknown argument name: {}", name) write!(f, "Unknown argument name: {name}")
} }
IncorrectArgType { name, expected, got } => { IncorrectArgType { name, expected, got } => {
let expected = self.unifier.stringify_with_notes(*expected, &mut notes); let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
let got = self.unifier.stringify_with_notes(*got, &mut notes); let got = self.unifier.stringify_with_notes(*got, &mut notes);
write!( write!(f, "Incorrect argument type for parameter {name}. Expected {expected}, but got {got}")
f,
"Incorrect argument type for {}. Expected {}, but got {}",
name, expected, got
)
} }
FieldUnificationError { field, types, loc } => { FieldUnificationError { field, types, loc } => {
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes); let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);
@ -126,22 +178,23 @@ impl<'a> Display for DisplayTypeError<'a> {
); );
if let Some(loc) = loc { if let Some(loc) = loc {
result?; result?;
write!(f, " (in {})", loc)?; write!(f, " (in {loc})")?;
return Ok(()); return Ok(());
} }
result result
} }
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) (
if ty1.len() != ty2.len() => TypeEnum::TTuple { ty: ty1, is_vararg_ctx: is_vararg1 },
{ TypeEnum::TTuple { ty: ty2, is_vararg_ctx: is_vararg2 },
) if !is_vararg1 && !is_vararg2 && ty1.len() != ty2.len() => {
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes); let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes); let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Tuple length mismatch: got {} and {}", t1, t2) write!(f, "Tuple length mismatch: got {t1} and {t2}")
} }
_ => { _ => {
let t1 = self.unifier.stringify_with_notes(*t1, &mut notes); let t1 = self.unifier.stringify_with_notes(*t1, &mut notes);
let t2 = self.unifier.stringify_with_notes(*t2, &mut notes); let t2 = self.unifier.stringify_with_notes(*t2, &mut notes);
write!(f, "Incompatible types: {} and {}", t1, t2) write!(f, "Incompatible types: {t1} and {t2}")
} }
} }
} }
@ -150,18 +203,21 @@ impl<'a> Display for DisplayTypeError<'a> {
write!(f, "Cannot assign to an element of a tuple") write!(f, "Cannot assign to an element of a tuple")
} else { } else {
let t = self.unifier.stringify_with_notes(*t, &mut notes); let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "Cannot assign to field {} of {}, which is immutable", name, t) write!(f, "Cannot assign to field {name} of {t}, which is immutable")
} }
} }
NoSuchField(name, t) => { NoSuchField(name, t) => {
let t = self.unifier.stringify_with_notes(*t, &mut notes); let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{}::{}` field/method does not exist", t, name) write!(f, "`{t}::{name}` field/method does not exist")
}
NoSuchAttribute(name, t) => {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` is not a class attribute")
} }
TupleIndexOutOfBounds { index, len } => { TupleIndexOutOfBounds { index, len } => {
write!( write!(
f, f,
"Tuple index out of bounds. Got {} but tuple has only {} elements", "Tuple index out of bounds. Got {index} but tuple has only {len} elements"
index, len
) )
} }
RequiresTypeAnn => { RequiresTypeAnn => {
@ -172,13 +228,13 @@ impl<'a> Display for DisplayTypeError<'a> {
} }
}?; }?;
if let Some(loc) = self.err.loc { if let Some(loc) = self.err.loc {
write!(f, " at {}", loc)?; write!(f, " at {loc}")?;
} }
let notes = notes.unwrap(); let notes = notes.unwrap();
if !notes.is_empty() { if !notes.is_empty() {
write!(f, "\n\nNotes:")?; write!(f, "\n\nNotes:")?;
for line in notes.values() { for line in notes.values() {
write!(f, "\n {}", line)?; write!(f, "\n {line}")?;
} }
} }
Ok(()) Ok(())

File diff suppressed because it is too large Load Diff

View File

@ -1,17 +1,21 @@
use super::super::{magic_methods::with_fields, typedef::*}; use std::iter::zip;
use super::*;
use crate::{ use indexmap::IndexMap;
codegen::CodeGenContext,
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
};
use indoc::indoc; use indoc::indoc;
use itertools::zip;
use nac3parser::parser::parse_program;
use parking_lot::RwLock; use parking_lot::RwLock;
use test_case::test_case; use test_case::test_case;
pub(crate) struct Resolver { use nac3parser::{ast::FileName, parser::parse_program};
use super::*;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
symbol_resolver::ValueEnum,
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
typecheck::{magic_methods::with_fields, typedef::*},
};
struct Resolver {
id_to_type: HashMap<StrRef, Type>, id_to_type: HashMap<StrRef, Type>,
id_to_def: HashMap<StrRef, DefinitionId>, id_to_def: HashMap<StrRef, DefinitionId>,
class_names: HashMap<StrRef, Type>, class_names: HashMap<StrRef, Type>,
@ -20,7 +24,7 @@ pub(crate) struct Resolver {
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_default_param_value( fn get_default_param_value(
&self, &self,
_: &nac3parser::ast::Expr, _: &ast::Expr,
) -> Option<crate::symbol_resolver::SymbolValue> { ) -> Option<crate::symbol_resolver::SymbolValue> {
unimplemented!() unimplemented!()
} }
@ -32,19 +36,23 @@ impl SymbolResolver for Resolver {
_: &PrimitiveStore, _: &PrimitiveStore,
str: StrRef, str: StrRef,
) -> Result<Type, String> { ) -> Result<Type, String> {
self.id_to_type.get(&str).cloned().ok_or_else(|| format!("cannot find symbol `{}`", str)) self.id_to_type.get(&str).copied().ok_or_else(|| format!("cannot find symbol `{str}`"))
} }
fn get_symbol_value<'ctx, 'a>( fn get_symbol_value<'ctx>(
&self, &self,
_: StrRef, _: StrRef,
_: &mut CodeGenContext<'ctx, 'a>, _: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
unimplemented!() unimplemented!()
} }
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, String> { fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.id_to_def.get(&id).cloned().ok_or_else(|| "Unknown identifier".to_string()) self.id_to_def
.get(&id)
.copied()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
} }
fn get_string_id(&self, _: &str) -> i32 { fn get_string_id(&self, _: &str) -> i32 {
@ -56,13 +64,13 @@ impl SymbolResolver for Resolver {
} }
} }
pub(crate) struct TestEnvironment { struct TestEnvironment {
pub unifier: Unifier, pub unifier: Unifier,
pub function_data: FunctionData, pub function_data: FunctionData,
pub primitives: PrimitiveStore, pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, StrRef>, pub id_to_name: HashMap<usize, StrRef>,
pub identifier_mapping: HashMap<StrRef, Type>, pub identifier_mapping: HashMap<StrRef, Type>,
pub virtual_checks: Vec<(Type, Type, nac3parser::ast::Location)>, pub virtual_checks: Vec<(Type, Type, Location)>,
pub calls: HashMap<CodeLocation, CallId>, pub calls: HashMap<CodeLocation, CallId>,
pub top_level: TopLevelContext, pub top_level: TopLevelContext,
} }
@ -72,67 +80,86 @@ impl TestEnvironment {
let mut unifier = Unifier::new(); let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0), obj_id: PrimDef::Int32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
with_fields(&mut unifier, int32, |unifier, fields| { with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], args: vec![FuncArg {
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
ret: int32, ret: int32,
vars: HashMap::new(), vars: VarMap::new(),
})); }));
fields.insert("__add__".into(), (add_ty, false)); fields.insert("__add__".into(), (add_ty, false));
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1), obj_id: PrimDef::Int64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2), obj_id: PrimDef::Float.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3), obj_id: PrimDef::Bool.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4), obj_id: PrimDef::None.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let range = unifier.add_ty(TypeEnum::TObj { let range = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: PrimDef::Range.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let str = unifier.add_ty(TypeEnum::TObj { let str = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(6), obj_id: PrimDef::Str.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let exception = unifier.add_ty(TypeEnum::TObj { let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(7), obj_id: PrimDef::Exception.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let uint32 = unifier.add_ty(TypeEnum::TObj { let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(8), obj_id: PrimDef::UInt32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let uint64 = unifier.add_ty(TypeEnum::TObj { let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(9), obj_id: PrimDef::UInt64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(10), obj_id: PrimDef::Option.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
});
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}); });
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
@ -146,10 +173,14 @@ impl TestEnvironment {
uint32, uint32,
uint64, uint64,
option, option,
list,
ndarray,
size_t: 64,
}; };
unifier.put_primitive_store(&primitives);
set_primitives_magic_methods(&primitives, &mut unifier); set_primitives_magic_methods(&primitives, &mut unifier);
let id_to_name = [ let id_to_name: HashMap<_, _> = [
(0, "int32".into()), (0, "int32".into()),
(1, "int64".into()), (1, "int64".into()),
(2, "float".into()), (2, "float".into()),
@ -159,23 +190,21 @@ impl TestEnvironment {
(6, "str".into()), (6, "str".into()),
(7, "exception".into()), (7, "exception".into()),
] ]
.iter() .into();
.cloned()
.collect();
let mut identifier_mapping = HashMap::new(); let mut identifier_mapping = HashMap::new();
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
let resolver = Arc::new(Resolver { let resolver = Arc::new(Resolver {
id_to_type: identifier_mapping.clone(), id_to_type: identifier_mapping.clone(),
id_to_def: Default::default(), id_to_def: HashMap::default(),
class_names: Default::default(), class_names: HashMap::default(),
}) as Arc<dyn SymbolResolver + Send + Sync>; }) as Arc<dyn SymbolResolver + Send + Sync>;
TestEnvironment { TestEnvironment {
top_level: TopLevelContext { top_level: TopLevelContext {
definitions: Default::default(), definitions: Arc::default(),
unifiers: Default::default(), unifiers: Arc::default(),
personality_symbol: None, personality_symbol: None,
}, },
unifier, unifier,
@ -192,86 +221,117 @@ impl TestEnvironment {
} }
} }
pub fn new() -> TestEnvironment { fn new() -> TestEnvironment {
let mut unifier = Unifier::new(); let mut unifier = Unifier::new();
let mut identifier_mapping = HashMap::new(); let mut identifier_mapping = HashMap::new();
let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new(); let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0), obj_id: PrimDef::Int32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
with_fields(&mut unifier, int32, |unifier, fields| { with_fields(&mut unifier, int32, |unifier, fields| {
let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let add_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { name: "other".into(), ty: int32, default_value: None }], args: vec![FuncArg {
name: "other".into(),
ty: int32,
default_value: None,
is_vararg: false,
}],
ret: int32, ret: int32,
vars: HashMap::new(), vars: VarMap::new(),
})); }));
fields.insert("__add__".into(), (add_ty, false)); fields.insert("__add__".into(), (add_ty, false));
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1), obj_id: PrimDef::Int64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2), obj_id: PrimDef::Float.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3), obj_id: PrimDef::Bool.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(4), obj_id: PrimDef::None.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let range = unifier.add_ty(TypeEnum::TObj { let range = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: PrimDef::Range.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let str = unifier.add_ty(TypeEnum::TObj { let str = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(6), obj_id: PrimDef::Str.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let exception = unifier.add_ty(TypeEnum::TObj { let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(7), obj_id: PrimDef::Exception.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let uint32 = unifier.add_ty(TypeEnum::TObj { let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(8), obj_id: PrimDef::UInt32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let uint64 = unifier.add_ty(TypeEnum::TObj { let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(9), obj_id: PrimDef::UInt64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}); });
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(10), obj_id: PrimDef::Option.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
});
let list_elem_tvar = unifier.get_fresh_var(Some("list_elem".into()), None);
let list = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([list_elem_tvar]),
});
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(),
params: VarMap::new(),
}); });
identifier_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"] for (i, name) in [
.iter() "int32",
.enumerate() "int64",
"float",
"bool",
"none",
"range",
"str",
"Exception",
"uint32",
"uint64",
"Option",
"list",
"ndarray",
]
.iter()
.enumerate()
{ {
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: (*name).into(), name: (*name).into(),
object_id: DefinitionId(i), object_id: DefinitionId(i),
type_vars: Default::default(), type_vars: Vec::default(),
fields: Default::default(), fields: Vec::default(),
methods: Default::default(), attributes: Vec::default(),
ancestors: Default::default(), methods: Vec::default(),
ancestors: Vec::default(),
resolver: None, resolver: None,
constructor: None, constructor: None,
loc: None, loc: None,
@ -279,7 +339,7 @@ impl TestEnvironment {
.into(), .into(),
); );
} }
let defs = 7; let defs = 12;
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
@ -293,23 +353,29 @@ impl TestEnvironment {
uint32, uint32,
uint64, uint64,
option, option,
list,
ndarray,
size_t: 64,
}; };
let (v0, id) = unifier.get_dummy_var(); unifier.put_primitive_store(&primitives);
let tvar = unifier.get_dummy_var();
let foo_ty = unifier.add_ty(TypeEnum::TObj { let foo_ty = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 1), obj_id: DefinitionId(defs + 1),
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(), fields: [("a".into(), (tvar.ty, true))].into(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>(), params: into_var_map([tvar]),
}); });
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: "Foo".into(), name: "Foo".into(),
object_id: DefinitionId(defs + 1), object_id: DefinitionId(defs + 1),
type_vars: vec![v0], type_vars: vec![tvar.ty],
fields: [("a".into(), v0, true)].into(), fields: [("a".into(), tvar.ty, true)].into(),
methods: Default::default(), attributes: Vec::default(),
ancestors: Default::default(), methods: Vec::default(),
ancestors: Vec::default(),
resolver: None, resolver: None,
constructor: None, constructor: None,
loc: None, loc: None,
@ -322,31 +388,29 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: foo_ty, ret: foo_ty,
vars: [(id, v0)].iter().cloned().collect(), vars: into_var_map([tvar]),
})), })),
); );
let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature { let fun = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: int32, ret: int32,
vars: Default::default(), vars: IndexMap::default(),
})); }));
let bar = unifier.add_ty(TypeEnum::TObj { let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 2), obj_id: DefinitionId(defs + 2),
fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))] fields: [("a".into(), (int32, true)), ("b".into(), (fun, true))].into(),
.iter() params: IndexMap::default(),
.cloned()
.collect::<HashMap<_, _>>(),
params: Default::default(),
}); });
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: "Bar".into(), name: "Bar".into(),
object_id: DefinitionId(defs + 2), object_id: DefinitionId(defs + 2),
type_vars: Default::default(), type_vars: Vec::default(),
fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(), fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(),
methods: Default::default(), attributes: Vec::default(),
ancestors: Default::default(), methods: Vec::default(),
ancestors: Vec::default(),
resolver: None, resolver: None,
constructor: None, constructor: None,
loc: None, loc: None,
@ -358,26 +422,24 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: bar, ret: bar,
vars: Default::default(), vars: IndexMap::default(),
})), })),
); );
let bar2 = unifier.add_ty(TypeEnum::TObj { let bar2 = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 3), obj_id: DefinitionId(defs + 3),
fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))] fields: [("a".into(), (bool, true)), ("b".into(), (fun, false))].into(),
.iter() params: IndexMap::default(),
.cloned()
.collect::<HashMap<_, _>>(),
params: Default::default(),
}); });
top_level_defs.push( top_level_defs.push(
RwLock::new(TopLevelDef::Class { RwLock::new(TopLevelDef::Class {
name: "Bar2".into(), name: "Bar2".into(),
object_id: DefinitionId(defs + 3), object_id: DefinitionId(defs + 3),
type_vars: Default::default(), type_vars: Vec::default(),
fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(), fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(),
methods: Default::default(), attributes: Vec::default(),
ancestors: Default::default(), methods: Vec::default(),
ancestors: Vec::default(),
resolver: None, resolver: None,
constructor: None, constructor: None,
loc: None, loc: None,
@ -389,10 +451,10 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: bar2, ret: bar2,
vars: Default::default(), vars: IndexMap::default(),
})), })),
); );
let class_names = [("Bar".into(), bar), ("Bar2".into(), bar2)].iter().cloned().collect(); let class_names: HashMap<_, _> = [("Bar".into(), bar), ("Bar2".into(), bar2)].into();
let id_to_name = [ let id_to_name = [
"int32".into(), "int32".into(),
@ -403,18 +465,22 @@ impl TestEnvironment {
"range".into(), "range".into(),
"str".into(), "str".into(),
"exception".into(), "exception".into(),
"uint32".into(),
"uint64".into(),
"option".into(),
"list".into(),
"ndarray".into(),
"Foo".into(), "Foo".into(),
"Bar".into(), "Bar".into(),
"Bar2".into(), "Bar2".into(),
] ]
.iter() .into_iter()
.enumerate() .enumerate()
.map(|(a, b)| (a, *b))
.collect(); .collect();
let top_level = TopLevelContext { let top_level = TopLevelContext {
definitions: Arc::new(top_level_defs.into()), definitions: Arc::new(top_level_defs.into()),
unifiers: Default::default(), unifiers: Arc::default(),
personality_symbol: None, personality_symbol: None,
}; };
@ -425,9 +491,7 @@ impl TestEnvironment {
("Bar".into(), DefinitionId(defs + 2)), ("Bar".into(), DefinitionId(defs + 2)),
("Bar2".into(), DefinitionId(defs + 3)), ("Bar2".into(), DefinitionId(defs + 3)),
] ]
.iter() .into(),
.cloned()
.collect(),
class_names, class_names,
}) as Arc<dyn SymbolResolver + Send + Sync>; }) as Arc<dyn SymbolResolver + Send + Sync>;
@ -447,16 +511,16 @@ impl TestEnvironment {
} }
} }
pub fn get_inferencer(&mut self) -> Inferencer { fn get_inferencer(&mut self) -> Inferencer {
Inferencer { Inferencer {
top_level: &self.top_level, top_level: &self.top_level,
function_data: &mut self.function_data, function_data: &mut self.function_data,
unifier: &mut self.unifier, unifier: &mut self.unifier,
variable_mapping: Default::default(), variable_mapping: HashMap::default(),
primitives: &mut self.primitives, primitives: &mut self.primitives,
virtual_checks: &mut self.virtual_checks, virtual_checks: &mut self.virtual_checks,
calls: &mut self.calls, calls: &mut self.calls,
defined_identifiers: Default::default(), defined_identifiers: HashMap::default(),
in_handler: false, in_handler: false,
} }
} }
@ -468,7 +532,7 @@ impl TestEnvironment {
c = 1.234 c = 1.234
d = True d = True
"}, "},
[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect(), &[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].into(),
&[] &[]
; "primitives test")] ; "primitives test")]
#[test_case(indoc! {" #[test_case(indoc! {"
@ -477,7 +541,7 @@ impl TestEnvironment {
c = 1.234 c = 1.234
d = b(c) d = b(c)
"}, "},
[("a", "fn[[x:float, y:float], float]"), ("b", "fn[[x:float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect(), &[("a", "fn[[x:float, y:float], float]"), ("b", "fn[[x:float], float]"), ("c", "float"), ("d", "float")].into(),
&[] &[]
; "lambda test")] ; "lambda test")]
#[test_case(indoc! {" #[test_case(indoc! {"
@ -486,7 +550,7 @@ impl TestEnvironment {
a = b a = b
c = b(1) c = b(1)
"}, "},
[("a", "fn[[x:int32], int32]"), ("b", "fn[[x:int32], int32]"), ("c", "int32")].iter().cloned().collect(), &[("a", "fn[[x:int32], int32]"), ("b", "fn[[x:int32], int32]"), ("c", "int32")].into(),
&[] &[]
; "lambda test 2")] ; "lambda test 2")]
#[test_case(indoc! {" #[test_case(indoc! {"
@ -502,15 +566,15 @@ impl TestEnvironment {
b(123) b(123)
"}, "},
[("a", "fn[[x:bool], bool]"), ("b", "fn[[x:int32], int32]"), ("c", "bool"), &[("a", "fn[[x:bool], bool]"), ("b", "fn[[x:int32], int32]"), ("c", "bool"),
("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect(), ("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].into(),
&[] &[]
; "obj test")] ; "obj test")]
#[test_case(indoc! {" #[test_case(indoc! {"
a = [1, 2, 3] a = [1, 2, 3]
b = [x + x for x in a] b = [x + x for x in a]
"}, "},
[("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(), &[("a", "list[int32]"), ("b", "list[int32]")].into(),
&[] &[]
; "listcomp test")] ; "listcomp test")]
#[test_case(indoc! {" #[test_case(indoc! {"
@ -518,25 +582,26 @@ impl TestEnvironment {
b = a.b() b = a.b()
a = virtual(Bar2()) a = virtual(Bar2())
"}, "},
[("a", "virtual[Bar]"), ("b", "int32")].iter().cloned().collect(), &[("a", "virtual[Bar]"), ("b", "int32")].into(),
&[("Bar", "Bar"), ("Bar2", "Bar")] &[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual test")] ; "virtual test")]
#[test_case(indoc! {" #[test_case(indoc! {"
a = [virtual(Bar(), Bar), virtual(Bar2())] a = [virtual(Bar(), Bar), virtual(Bar2())]
b = [x.b() for x in a] b = [x.b() for x in a]
"}, "},
[("a", "list[virtual[Bar]]"), ("b", "list[int32]")].iter().cloned().collect(), &[("a", "list[virtual[Bar]]"), ("b", "list[int32]")].into(),
&[("Bar", "Bar"), ("Bar2", "Bar")] &[("Bar", "Bar"), ("Bar2", "Bar")]
; "virtual list test")] ; "virtual list test")]
fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &str)]) { fn test_basic(source: &str, mapping: &HashMap<&str, &str>, virtuals: &[(&str, &str)]) {
println!("source:\n{}", source); println!("source:\n{source}");
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name); let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect(); let mut defined_identifiers: HashMap<_, _> =
defined_identifiers.insert("virtual".into()); env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect();
defined_identifiers.insert("virtual".into(), IdentifierInfo::default());
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone(); inferencer.defined_identifiers.clone_from(&defined_identifiers);
let statements = parse_program(source, Default::default()).unwrap(); let statements = parse_program(source, FileName::default()).unwrap();
let statements = statements let statements = statements
.into_iter() .into_iter()
.map(|v| inferencer.fold_stmt(v)) .map(|v| inferencer.fold_stmt(v))
@ -545,37 +610,37 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
inferencer.check_block(&statements, &mut defined_identifiers).unwrap(); inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in inferencer.variable_mapping.iter() { for (k, v) in &inferencer.variable_mapping {
let name = inferencer.unifier.internal_stringify( let name = inferencer.unifier.internal_stringify(
*v, *v,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{v}"),
&mut None, &mut None,
); );
println!("{}: {}", k, name); println!("{k}: {name}");
} }
for (k, v) in mapping.iter() { for (k, v) in mapping {
let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap(); let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap();
let name = inferencer.unifier.internal_stringify( let name = inferencer.unifier.internal_stringify(
*ty, *ty,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{v}"),
&mut None, &mut None,
); );
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); assert_eq!(format!("{k}: {v}"), format!("{k}: {name}"));
} }
assert_eq!(inferencer.virtual_checks.len(), virtuals.len()); assert_eq!(inferencer.virtual_checks.len(), virtuals.len());
for ((a, b, _), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) { for ((a, b, _), (x, y)) in zip(inferencer.virtual_checks.iter(), virtuals) {
let a = inferencer.unifier.internal_stringify( let a = inferencer.unifier.internal_stringify(
*a, *a,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{v}"),
&mut None, &mut None,
); );
let b = inferencer.unifier.internal_stringify( let b = inferencer.unifier.internal_stringify(
*b, *b,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{v}"),
&mut None, &mut None,
); );
@ -594,14 +659,14 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
g = a // b g = a // b
h = a % b h = a % b
"}, "},
[("a", "int32"), &[("a", "int32"),
("b", "int32"), ("b", "int32"),
("c", "int32"), ("c", "int32"),
("d", "int32"), ("d", "int32"),
("e", "int32"), ("e", "int32"),
("f", "float"), ("f", "float"),
("g", "int32"), ("g", "int32"),
("h", "int32")].iter().cloned().collect() ("h", "int32")].into()
; "int32")] ; "int32")]
#[test_case( #[test_case(
indoc! {" indoc! {"
@ -617,7 +682,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
ii = 3 ii = 3
j = a ** b j = a ** b
"}, "},
[("a", "float"), &[("a", "float"),
("b", "float"), ("b", "float"),
("c", "float"), ("c", "float"),
("d", "float"), ("d", "float"),
@ -627,7 +692,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
("h", "float"), ("h", "float"),
("i", "float"), ("i", "float"),
("ii", "int32"), ("ii", "int32"),
("j", "float")].iter().cloned().collect() ("j", "float")].into()
; "float" ; "float"
)] )]
#[test_case( #[test_case(
@ -645,7 +710,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
k = a < b k = a < b
l = a != b l = a != b
"}, "},
[("a", "int64"), &[("a", "int64"),
("b", "int64"), ("b", "int64"),
("c", "int64"), ("c", "int64"),
("d", "int64"), ("d", "int64"),
@ -656,7 +721,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
("i", "bool"), ("i", "bool"),
("j", "bool"), ("j", "bool"),
("k", "bool"), ("k", "bool"),
("l", "bool")].iter().cloned().collect() ("l", "bool")].into()
; "int64" ; "int64"
)] )]
#[test_case( #[test_case(
@ -667,22 +732,23 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st
d = not a d = not a
e = a != b e = a != b
"}, "},
[("a", "bool"), &[("a", "bool"),
("b", "bool"), ("b", "bool"),
("c", "bool"), ("c", "bool"),
("d", "bool"), ("d", "bool"),
("e", "bool")].iter().cloned().collect() ("e", "bool")].into()
; "boolean" ; "boolean"
)] )]
fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) { fn test_primitive_magic_methods(source: &str, mapping: &HashMap<&str, &str>) {
println!("source:\n{}", source); println!("source:\n{source}");
let mut env = TestEnvironment::basic_test_env(); let mut env = TestEnvironment::basic_test_env();
let id_to_name = std::mem::take(&mut env.id_to_name); let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect(); let mut defined_identifiers: HashMap<_, _> =
defined_identifiers.insert("virtual".into()); env.identifier_mapping.keys().copied().map(|id| (id, IdentifierInfo::default())).collect();
defined_identifiers.insert("virtual".into(), IdentifierInfo::default());
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
inferencer.defined_identifiers = defined_identifiers.clone(); inferencer.defined_identifiers.clone_from(&defined_identifiers);
let statements = parse_program(source, Default::default()).unwrap(); let statements = parse_program(source, FileName::default()).unwrap();
let statements = statements let statements = statements
.into_iter() .into_iter()
.map(|v| inferencer.fold_stmt(v)) .map(|v| inferencer.fold_stmt(v))
@ -691,23 +757,23 @@ fn test_primitive_magic_methods(source: &str, mapping: HashMap<&str, &str>) {
inferencer.check_block(&statements, &mut defined_identifiers).unwrap(); inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in inferencer.variable_mapping.iter() { for (k, v) in &inferencer.variable_mapping {
let name = inferencer.unifier.internal_stringify( let name = inferencer.unifier.internal_stringify(
*v, *v,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{v}"),
&mut None, &mut None,
); );
println!("{}: {}", k, name); println!("{k}: {name}");
} }
for (k, v) in mapping.iter() { for (k, v) in mapping {
let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap(); let ty = inferencer.variable_mapping.get(&(*k).into()).unwrap();
let name = inferencer.unifier.internal_stringify( let name = inferencer.unifier.internal_stringify(
*ty, *ty,
&mut |v| (*id_to_name.get(&v).unwrap()).into(), &mut |v| (*id_to_name.get(&v).unwrap()).into(),
&mut |v| format!("v{}", v), &mut |v| format!("v{v}"),
&mut None, &mut None,
); );
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name)); assert_eq!(format!("{k}: {v}"), format!("{k}: {name}"));
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,12 @@
use super::super::magic_methods::with_fields; use std::collections::HashMap;
use super::*;
use indoc::indoc; use indoc::indoc;
use itertools::Itertools; use itertools::Itertools;
use std::collections::HashMap;
use test_case::test_case; use test_case::test_case;
use super::*;
use crate::typecheck::magic_methods::with_fields;
impl Unifier { impl Unifier {
/// Check whether two types are equal. /// Check whether two types are equal.
fn eq(&mut self, a: Type, b: Type) -> bool { fn eq(&mut self, a: Type, b: Type) -> bool {
@ -28,32 +30,32 @@ impl Unifier {
TypeEnum::TVar { fields: Some(map1), .. }, TypeEnum::TVar { fields: Some(map1), .. },
TypeEnum::TVar { fields: Some(map2), .. }, TypeEnum::TVar { fields: Some(map2), .. },
) => self.map_eq2(map1, map2), ) => self.map_eq2(map1, map2),
(TypeEnum::TTuple { ty: ty1 }, TypeEnum::TTuple { ty: ty2 }) => { (
TypeEnum::TTuple { ty: ty1, is_vararg_ctx: false },
TypeEnum::TTuple { ty: ty2, is_vararg_ctx: false },
) => {
ty1.len() == ty2.len() ty1.len() == ty2.len()
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2)) && ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
} }
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 }) (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => self.eq(*ty1, *ty2),
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
self.eq(*ty1, *ty2)
}
( (
TypeEnum::TObj { obj_id: id1, params: params1, .. }, TypeEnum::TObj { obj_id: id1, params: params1, .. },
TypeEnum::TObj { obj_id: id2, params: params2, .. }, TypeEnum::TObj { obj_id: id2, params: params2, .. },
) => id1 == id2 && self.map_eq(params1, params2), ) => id1 == id2 && self.map_eq(params1, params2),
// TCall and TFunc are not yet implemented // TLiteral, TCall and TFunc are not yet implemented
_ => false, _ => false,
} }
} }
fn map_eq<K>(&mut self, map1: &Mapping<K>, map2: &Mapping<K>) -> bool fn map_eq<K>(&mut self, map1: &IndexMapping<K>, map2: &IndexMapping<K>) -> bool
where where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, K: std::hash::Hash + Eq + Clone,
{ {
if map1.len() != map2.len() { if map1.len() != map2.len() {
return false; return false;
} }
for (k, v) in map1.iter() { for (k, v) in map1 {
if !map2.get(k).map(|v1| self.eq(*v, *v1)).unwrap_or(false) { if !map2.get(k).is_some_and(|v1| self.eq(*v, *v1)) {
return false; return false;
} }
} }
@ -62,13 +64,13 @@ impl Unifier {
fn map_eq2<K>(&mut self, map1: &Mapping<K, RecordField>, map2: &Mapping<K, RecordField>) -> bool fn map_eq2<K>(&mut self, map1: &Mapping<K, RecordField>, map2: &Mapping<K, RecordField>) -> bool
where where
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone, K: std::hash::Hash + Eq + Clone,
{ {
if map1.len() != map2.len() { if map1.len() != map2.len() {
return false; return false;
} }
for (k, v) in map1.iter() { for (k, v) in map1 {
if !map2.get(k).map(|v1| self.eq(v.ty, v1.ty)).unwrap_or(false) { if !map2.get(k).is_some_and(|v1| self.eq(v.ty, v1.ty)) {
return false; return false;
} }
} }
@ -91,7 +93,7 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(0), obj_id: DefinitionId(0),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}), }),
); );
type_mapping.insert( type_mapping.insert(
@ -99,7 +101,7 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(1), obj_id: DefinitionId(1),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}), }),
); );
type_mapping.insert( type_mapping.insert(
@ -107,16 +109,25 @@ impl TestEnvironment {
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(2), obj_id: DefinitionId(2),
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: VarMap::new(),
}), }),
); );
let (v0, id) = unifier.get_dummy_var(); let tvar = unifier.get_dummy_var();
type_mapping.insert( type_mapping.insert(
"Foo".into(), "Foo".into(),
unifier.add_ty(TypeEnum::TObj { unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(3), obj_id: DefinitionId(3),
fields: [("a".into(), (v0, true))].iter().cloned().collect::<HashMap<_, _>>(), fields: [("a".into(), (tvar.ty, true))].into(),
params: [(id, v0)].iter().cloned().collect::<HashMap<_, _>>(), params: into_var_map([tvar]),
}),
);
let tvar = unifier.get_dummy_var();
type_mapping.insert(
"list".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::List.id(),
fields: HashMap::new(),
params: into_var_map([tvar]),
}), }),
); );
@ -129,34 +140,54 @@ impl TestEnvironment {
result.0 result.0
} }
fn internal_parse<'a, 'b>( fn internal_parse<'b>(&mut self, typ: &'b str, mapping: &Mapping<String>) -> (Type, &'b str) {
&'a mut self,
typ: &'b str,
mapping: &Mapping<String>,
) -> (Type, &'b str) {
// for testing only, so we can just panic when the input is malformed // for testing only, so we can just panic when the input is malformed
let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or_else(|| typ.len()); let end = typ.find(|c| ['[', ',', ']', '='].contains(&c)).unwrap_or(typ.len());
match &typ[..end] { match &typ[..end] {
"tuple" => { "list" => {
let mut s = &typ[end..]; let mut s = &typ[end..];
assert!(&s[0..1] == "["); assert_eq!(&s[0..1], "[");
let mut ty = Vec::new(); let mut ty = Vec::new();
while &s[0..1] != "]" { while &s[0..1] != "]" {
let result = self.internal_parse(&s[1..], mapping); let result = self.internal_parse(&s[1..], mapping);
ty.push(result.0); ty.push(result.0);
s = result.1; s = result.1;
} }
(self.unifier.add_ty(TypeEnum::TTuple { ty }), &s[1..])
assert_eq!(ty.len(), 1);
let list_elem_tvar = if let TypeEnum::TObj { params, .. } =
&*self.unifier.get_ty_immutable(self.type_mapping["list"])
{
iter_type_vars(params).next().unwrap()
} else {
unreachable!()
};
(
self.unifier
.subst(
self.type_mapping["list"],
&into_var_map([TypeVar { id: list_elem_tvar.id, ty: ty[0] }]),
)
.unwrap(),
&s[1..],
)
} }
"list" => { "tuple" => {
assert!(&typ[end..end + 1] == "["); let mut s = &typ[end..];
let (ty, s) = self.internal_parse(&typ[end + 1..], mapping); assert_eq!(&s[0..1], "[");
assert!(&s[0..1] == "]"); let mut ty = Vec::new();
(self.unifier.add_ty(TypeEnum::TList { ty }), &s[1..]) while &s[0..1] != "]" {
let result = self.internal_parse(&s[1..], mapping);
ty.push(result.0);
s = result.1;
}
(self.unifier.add_ty(TypeEnum::TTuple { ty, is_vararg_ctx: false }), &s[1..])
} }
"Record" => { "Record" => {
let mut s = &typ[end..]; let mut s = &typ[end..];
assert!(&s[0..1] == "["); assert_eq!(&s[0..1], "[");
let mut fields = HashMap::new(); let mut fields = HashMap::new();
while &s[0..1] != "]" { while &s[0..1] != "]" {
let eq = s.find('=').unwrap(); let eq = s.find('=').unwrap();
@ -169,14 +200,14 @@ impl TestEnvironment {
} }
x => { x => {
let mut s = &typ[end..]; let mut s = &typ[end..];
let ty = mapping.get(x).cloned().unwrap_or_else(|| { let ty = mapping.get(x).copied().unwrap_or_else(|| {
// mapping should be type variables, type_mapping should be concrete types // mapping should be type variables, type_mapping should be concrete types
// we should not resolve the type of type variables. // we should not resolve the type of type variables.
let mut ty = *self.type_mapping.get(x).unwrap(); let mut ty = *self.type_mapping.get(x).unwrap();
let te = self.unifier.get_ty(ty); let te = self.unifier.get_ty(ty);
if let TypeEnum::TObj { params, .. } = &*te.as_ref() { if let TypeEnum::TObj { params, .. } = &*te {
if !params.is_empty() { if !params.is_empty() {
assert!(&s[0..1] == "["); assert_eq!(&s[0..1], "[");
let mut p = Vec::new(); let mut p = Vec::new();
while &s[0..1] != "]" { while &s[0..1] != "]" {
let result = self.internal_parse(&s[1..], mapping); let result = self.internal_parse(&s[1..], mapping);
@ -186,7 +217,7 @@ impl TestEnvironment {
s = &s[1..]; s = &s[1..];
ty = self ty = self
.unifier .unifier
.subst(ty, &params.keys().cloned().zip(p.into_iter()).collect()) .subst(ty, &params.keys().copied().zip(p).collect())
.unwrap_or(ty); .unwrap_or(ty);
} }
} }
@ -250,12 +281,12 @@ fn test_unify(
let mut mapping = HashMap::new(); let mut mapping = HashMap::new();
for i in 1..=variable_count { for i in 1..=variable_count {
let v = env.unifier.get_dummy_var(); let v = env.unifier.get_dummy_var();
mapping.insert(format!("v{}", i), v.0); mapping.insert(format!("v{i}"), v.ty);
} }
// unification may have side effect when we do type resolution, so freeze the types // unification may have side effect when we do type resolution, so freeze the types
// before doing unification. // before doing unification.
let mut pairs = Vec::new(); let mut pairs = Vec::new();
for (a, b) in perm.iter() { for (a, b) in &perm {
let t1 = env.parse(a, &mapping); let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping); let t2 = env.parse(b, &mapping);
pairs.push((t1, t2)); pairs.push((t1, t2));
@ -263,8 +294,8 @@ fn test_unify(
for (t1, t2) in pairs { for (t1, t2) in pairs {
env.unifier.unify(t1, t2).unwrap(); env.unifier.unify(t1, t2).unwrap();
} }
for (a, b) in verify_pairs.iter() { for (a, b) in verify_pairs {
println!("{} = {}", a, b); println!("{a} = {b}");
let t1 = env.parse(a, &mapping); let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping); let t2 = env.parse(b, &mapping);
println!("a = {}, b = {}", env.unifier.stringify(t1), env.unifier.stringify(t2)); println!("a = {}, b = {}", env.unifier.stringify(t1), env.unifier.stringify(t2));
@ -278,7 +309,7 @@ fn test_unify(
("v1", "tuple[int]"), ("v1", "tuple[int]"),
("v2", "list[int]"), ("v2", "list[int]"),
], ],
(("v1", "v2"), "Incompatible types: list[0] and tuple[0]") (("v1", "v2"), "Incompatible types: 11[0] and tuple[0]")
; "type mismatch" ; "type mismatch"
)] )]
#[test_case(2, #[test_case(2,
@ -286,7 +317,7 @@ fn test_unify(
("v1", "tuple[int]"), ("v1", "tuple[int]"),
("v2", "tuple[float]"), ("v2", "tuple[float]"),
], ],
(("v1", "v2"), "Incompatible types: 0 and 1") (("v1", "v2"), "Incompatible types: tuple[0] and tuple[1]")
; "tuple parameter mismatch" ; "tuple parameter mismatch"
)] )]
#[test_case(2, #[test_case(2,
@ -302,35 +333,35 @@ fn test_unify(
("v1", "Record[a=float,b=int]"), ("v1", "Record[a=float,b=int]"),
("v2", "Foo[v3]"), ("v2", "Foo[v3]"),
], ],
(("v1", "v2"), "`3[var4]::b` field/method does not exist") (("v1", "v2"), "`3[typevar5]::b` field/method does not exist")
; "record obj merge" ; "record obj merge"
)] )]
/// Test cases for invalid unifications. /// Test cases for invalid unifications.
fn test_invalid_unification( fn test_invalid_unification(
variable_count: u32, variable_count: u32,
unify_pairs: &[(&'static str, &'static str)], unify_pairs: &[(&'static str, &'static str)],
errornous_pair: ((&'static str, &'static str), &'static str), erroneous_pair: ((&'static str, &'static str), &'static str),
) { ) {
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let mut mapping = HashMap::new(); let mut mapping = HashMap::new();
for i in 1..=variable_count { for i in 1..=variable_count {
let v = env.unifier.get_dummy_var(); let v = env.unifier.get_dummy_var();
mapping.insert(format!("v{}", i), v.0); mapping.insert(format!("v{i}"), v.ty);
} }
// unification may have side effect when we do type resolution, so freeze the types // unification may have side effect when we do type resolution, so freeze the types
// before doing unification. // before doing unification.
let mut pairs = Vec::new(); let mut pairs = Vec::new();
for (a, b) in unify_pairs.iter() { for (a, b) in unify_pairs {
let t1 = env.parse(a, &mapping); let t1 = env.parse(a, &mapping);
let t2 = env.parse(b, &mapping); let t2 = env.parse(b, &mapping);
pairs.push((t1, t2)); pairs.push((t1, t2));
} }
let (t1, t2) = let (t1, t2) =
(env.parse(errornous_pair.0 .0, &mapping), env.parse(errornous_pair.0 .1, &mapping)); (env.parse(erroneous_pair.0 .0, &mapping), env.parse(erroneous_pair.0 .1, &mapping));
for (a, b) in pairs { for (a, b) in pairs {
env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, b).unwrap();
} }
assert_eq!(env.unify(t1, t2), Err(errornous_pair.1.to_string())); assert_eq!(env.unify(t1, t2), Err(erroneous_pair.1.to_string()));
} }
#[test] #[test]
@ -339,23 +370,17 @@ fn test_recursive_subst() {
let int = *env.type_mapping.get("int").unwrap(); let int = *env.type_mapping.get("int").unwrap();
let foo_id = *env.type_mapping.get("Foo").unwrap(); let foo_id = *env.type_mapping.get("Foo").unwrap();
let foo_ty = env.unifier.get_ty(foo_id); let foo_ty = env.unifier.get_ty(foo_id);
let mapping: HashMap<_, _>;
with_fields(&mut env.unifier, foo_id, |_unifier, fields| { with_fields(&mut env.unifier, foo_id, |_unifier, fields| {
fields.insert("rec".into(), (foo_id, true)); fields.insert("rec".into(), (foo_id, true));
}); });
if let TypeEnum::TObj { params, .. } = &*foo_ty { let TypeEnum::TObj { params, .. } = &*foo_ty else { unreachable!() };
mapping = params.iter().map(|(id, _)| (*id, int)).collect(); let mapping = params.iter().map(|(id, _)| (*id, int)).collect();
} else {
unreachable!()
}
let instantiated = env.unifier.subst(foo_id, &mapping).unwrap(); let instantiated = env.unifier.subst(foo_id, &mapping).unwrap();
let instantiated_ty = env.unifier.get_ty(instantiated); let instantiated_ty = env.unifier.get_ty(instantiated);
if let TypeEnum::TObj { fields, .. } = &*instantiated_ty {
assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int)); let TypeEnum::TObj { fields, .. } = &*instantiated_ty else { unreachable!() };
assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated)); assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int));
} else { assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated));
unreachable!()
}
} }
#[test] #[test]
@ -365,36 +390,27 @@ fn test_virtual() {
let fun = env.unifier.add_ty(TypeEnum::TFunc(FunSignature { let fun = env.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: int, ret: int,
vars: HashMap::new(), vars: VarMap::new(),
})); }));
let bar = env.unifier.add_ty(TypeEnum::TObj { let bar = env.unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: DefinitionId(5),
fields: [("f".into(), (fun, false)), ("a".into(), (int, false))] fields: [("f".into(), (fun, false)), ("a".into(), (int, false))].into(),
.iter() params: VarMap::new(),
.cloned()
.collect::<HashMap<StrRef, _>>(),
params: HashMap::new(),
}); });
let v0 = env.unifier.get_dummy_var().0; let v0 = env.unifier.get_dummy_var().ty;
let v1 = env.unifier.get_dummy_var().0; let v1 = env.unifier.get_dummy_var().ty;
let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar }); let a = env.unifier.add_ty(TypeEnum::TVirtual { ty: bar });
let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 }); let b = env.unifier.add_ty(TypeEnum::TVirtual { ty: v0 });
let c = env let c = env.unifier.add_record([("f".into(), RecordField::new(v1, false, None))].into());
.unifier
.add_record([("f".into(), RecordField::new(v1, false, None))].iter().cloned().collect());
env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, b).unwrap();
env.unifier.unify(b, c).unwrap(); env.unifier.unify(b, c).unwrap();
assert!(env.unifier.eq(v1, fun)); assert!(env.unifier.eq(v1, fun));
let d = env let d = env.unifier.add_record([("a".into(), RecordField::new(v1, true, None))].into());
.unifier
.add_record([("a".into(), RecordField::new(v1, true, None))].iter().cloned().collect());
assert_eq!(env.unify(b, d), Err("`virtual[5]::a` field/method does not exist".to_string())); assert_eq!(env.unify(b, d), Err("`virtual[5]::a` field/method does not exist".to_string()));
let d = env let d = env.unifier.add_record([("b".into(), RecordField::new(v1, true, None))].into());
.unifier
.add_record([("b".into(), RecordField::new(v1, true, None))].iter().cloned().collect());
assert_eq!(env.unify(b, d), Err("`virtual[5]::b` field/method does not exist".to_string())); assert_eq!(env.unify(b, d), Err("`virtual[5]::b` field/method does not exist".to_string()));
} }
@ -407,85 +423,132 @@ fn test_typevar_range() {
let int_list = env.parse("list[int]", &HashMap::new()); let int_list = env.parse("list[int]", &HashMap::new());
let float_list = env.parse("list[float]", &HashMap::new()); let float_list = env.parse("list[float]", &HashMap::new());
let list_elem_tvar = if let TypeEnum::TObj { params, .. } =
&*env.unifier.get_ty_immutable(env.type_mapping["list"])
{
iter_type_vars(params).next().unwrap()
} else {
unreachable!()
};
// unification between v and int // unification between v and int
// where v in (int, bool) // where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
env.unifier.unify(int, v).unwrap(); env.unifier.unify(int, v).unwrap();
// unification between v and list[int] // unification between v and list[int]
// where v in (int, bool) // where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
assert_eq!( assert_eq!(
env.unify(int_list, v), env.unify(int_list, v),
Err("Expected any one of these types: 0, 2, but got list[0]".to_string()) Err("Expected any one of these types: 0, 2, but got 11[0]".to_string())
); );
// unification between v and float // unification between v and float
// where v in (int, bool) // where v in (int, bool)
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
assert_eq!( assert_eq!(
env.unify(float, v), env.unify(float, v),
Err("Expected any one of these types: 0, 2, but got 1".to_string()) Err("Expected any one of these types: 0, 2, but got 1".to_string())
); );
let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let v1 = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let v1_list = env.unifier.add_ty(TypeEnum::TList { ty: v1 }); let v1_list = env.unifier.add_ty(TypeEnum::TObj {
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: v1 }]),
});
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
// unification between v and int // unification between v and int
// where v in (int, list[v1]), v1 in (int, bool) // where v in (int, list[v1]), v1 in (int, bool)
env.unifier.unify(int, v).unwrap(); env.unifier.unify(int, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
// unification between v and list[int] // unification between v and list[int]
// where v in (int, list[v1]), v1 in (int, bool) // where v in (int, list[v1]), v1 in (int, bool)
env.unifier.unify(int_list, v).unwrap(); env.unifier.unify(int_list, v).unwrap();
let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, v1_list], None, None).ty;
// unification between v and list[float] // unification between v and list[float]
// where v in (int, list[v1]), v1 in (int, bool) // where v in (int, list[v1]), v1 in (int, bool)
println!("float_list: {}, v: {}", env.unifier.stringify(float_list), env.unifier.stringify(v));
assert_eq!( assert_eq!(
env.unify(float_list, v), env.unify(float_list, v),
Err("Expected any one of these types: 0, list[var5], but got list[1]\n\nNotes:\n var5 ∈ {0, 2}".to_string()) Err("Expected any one of these types: 0, 11[typevar6], but got 11[1]\n\nNotes:\n typevar6 ∈ {0, 2}".to_string())
); );
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, b).unwrap();
env.unifier.unify(a, float).unwrap(); env.unifier.unify(a, float).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
env.unifier.unify(a, b).unwrap(); env.unifier.unify(a, b).unwrap();
assert_eq!(env.unify(a, int), Err("Expected any one of these types: 1, but got 0".into())); assert_eq!(env.unify(a, int), Err("Expected any one of these types: 1, but got 0".into()));
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let a_list = env.unifier.add_ty(TypeEnum::TObj {
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0; obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); fields: Mapping::default(),
let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).0; params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty;
let b_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: b }]),
});
let b_list = env.unifier.get_fresh_var_with_range(&[b_list], None, None).ty;
env.unifier.unify(a_list, b_list).unwrap(); env.unifier.unify(a_list, b_list).unwrap();
let float_list = env.unifier.add_ty(TypeEnum::TList { ty: float }); let float_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: float }]),
});
env.unifier.unify(a_list, float_list).unwrap(); env.unifier.unify(a_list, float_list).unwrap();
// previous unifications should not affect a and b // previous unifications should not affect a and b
env.unifier.unify(a, int).unwrap(); env.unifier.unify(a, int).unwrap();
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).0; let b = env.unifier.get_fresh_var_with_range(&[boolean, float], None, None).ty;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let a_list = env.unifier.add_ty(TypeEnum::TObj {
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let b_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: b }]),
});
env.unifier.unify(a_list, b_list).unwrap(); env.unifier.unify(a_list, b_list).unwrap();
let int_list = env.unifier.add_ty(TypeEnum::TList { ty: int }); let int_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: int }]),
});
assert_eq!( assert_eq!(
env.unify(a_list, int_list), env.unify(a_list, int_list),
Err("Expected any one of these types: 1, but got 0".into()) Err("Incompatible types: 11[typevar23] and 11[0]\
\n\nNotes:\n typevar23 {1}"
.into())
); );
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).ty;
let b = env.unifier.get_dummy_var().0; let b = env.unifier.get_dummy_var().ty;
let a_list = env.unifier.add_ty(TypeEnum::TList { ty: a }); let a_list = env.unifier.add_ty(TypeEnum::TObj {
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).0; obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
let b_list = env.unifier.add_ty(TypeEnum::TList { ty: b }); fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let a_list = env.unifier.get_fresh_var_with_range(&[a_list], None, None).ty;
let b_list = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: b }]),
});
env.unifier.unify(a_list, b_list).unwrap(); env.unifier.unify(a_list, b_list).unwrap();
assert_eq!( assert_eq!(
env.unify(b, boolean), env.unify(b, boolean),
@ -496,17 +559,29 @@ fn test_typevar_range() {
#[test] #[test]
fn test_rigid_var() { fn test_rigid_var() {
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let a = env.unifier.get_fresh_rigid_var(None, None).0; let a = env.unifier.get_fresh_rigid_var(None, None).ty;
let b = env.unifier.get_fresh_rigid_var(None, None).0; let b = env.unifier.get_fresh_rigid_var(None, None).ty;
let x = env.unifier.get_dummy_var().0; let x = env.unifier.get_dummy_var().ty;
let list_a = env.unifier.add_ty(TypeEnum::TList { ty: a }); let list_elem_tvar = env.unifier.get_fresh_var(Some("list_elem".into()), None);
let list_x = env.unifier.add_ty(TypeEnum::TList { ty: x }); let list_a = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: a }]),
});
let list_x = env.unifier.add_ty(TypeEnum::TObj {
obj_id: env.type_mapping["list"].obj_id(&env.unifier).unwrap(),
fields: Mapping::default(),
params: into_var_map([TypeVar { id: list_elem_tvar.id, ty: x }]),
});
let int = env.parse("int", &HashMap::new()); let int = env.parse("int", &HashMap::new());
let list_int = env.parse("list[int]", &HashMap::new()); let list_int = env.parse("list[int]", &HashMap::new());
assert_eq!(env.unify(a, b), Err("Incompatible types: var3 and var2".to_string())); assert_eq!(env.unify(a, b), Err("Incompatible types: typevar4 and typevar3".to_string()));
env.unifier.unify(list_a, list_x).unwrap(); env.unifier.unify(list_a, list_x).unwrap();
assert_eq!(env.unify(list_x, list_int), Err("Incompatible types: 0 and var2".to_string())); assert_eq!(
env.unify(list_x, list_int),
Err("Incompatible types: 11[typevar3] and 11[0]".to_string())
);
env.unifier.replace_rigid_var(a, int); env.unifier.replace_rigid_var(a, int);
env.unifier.unify(list_x, list_int).unwrap(); env.unifier.unify(list_x, list_int).unwrap();
@ -520,16 +595,26 @@ fn test_instantiation() {
let float = env.parse("float", &HashMap::new()); let float = env.parse("float", &HashMap::new());
let list_int = env.parse("list[int]", &HashMap::new()); let list_int = env.parse("list[int]", &HashMap::new());
let obj_map: HashMap<_, _> = let list_elem_tvar = if let TypeEnum::TObj { params, .. } =
[(0usize, "int"), (1, "float"), (2, "bool")].iter().cloned().collect(); &*env.unifier.get_ty_immutable(env.type_mapping["list"])
{
iter_type_vars(params).next().unwrap()
} else {
unreachable!()
};
let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).0; let obj_map: HashMap<_, _> = [(0usize, "int"), (1, "float"), (2, "bool"), (11, "list")].into();
let list_v = env.unifier.add_ty(TypeEnum::TList { ty: v });
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).0; let v = env.unifier.get_fresh_var_with_range(&[int, boolean], None, None).ty;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).0; let list_v = env
let t = env.unifier.get_dummy_var().0; .unifier
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2] }); .subst(env.type_mapping["list"], &into_var_map([TypeVar { id: list_elem_tvar.id, ty: v }]))
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).0; .unwrap();
let v1 = env.unifier.get_fresh_var_with_range(&[list_v, int], None, None).ty;
let v2 = env.unifier.get_fresh_var_with_range(&[list_int, float], None, None).ty;
let t = env.unifier.get_dummy_var().ty;
let tuple = env.unifier.add_ty(TypeEnum::TTuple { ty: vec![v, v1, v2], is_vararg_ctx: false });
let v3 = env.unifier.get_fresh_var_with_range(&[tuple, t], None, None).ty;
// t = TypeVar('t') // t = TypeVar('t')
// v = TypeVar('v', int, bool) // v = TypeVar('v', int, bool)
// v1 = TypeVar('v1', 'list[v]', int) // v1 = TypeVar('v1', 'list[v]', int)
@ -551,7 +636,7 @@ fn test_instantiation() {
tuple[int, list[bool], list[int]] tuple[int, list[bool], list[int]]
tuple[int, list[int], float] tuple[int, list[int], float]
tuple[int, list[int], list[int]] tuple[int, list[int], list[int]]
v5" v6"
} }
.split('\n') .split('\n')
.collect_vec(); .collect_vec();
@ -560,8 +645,8 @@ fn test_instantiation() {
.map(|ty| { .map(|ty| {
env.unifier.internal_stringify( env.unifier.internal_stringify(
*ty, *ty,
&mut |i| obj_map.get(&i).unwrap().to_string(), &mut |i| (*obj_map.get(&i).unwrap()).to_string(),
&mut |i| format!("v{}", i), &mut |i| format!("v{i}"),
&mut None, &mut None,
) )
}) })

View File

@ -3,7 +3,7 @@ use std::rc::Rc;
use itertools::izip; use itertools::izip;
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct UnificationKey(pub(crate) usize); pub struct UnificationKey(usize);
#[derive(Clone)] #[derive(Clone)]
pub struct UnificationTable<V> { pub struct UnificationTable<V> {
@ -16,21 +16,10 @@ pub struct UnificationTable<V> {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
enum Action<V> { enum Action<V> {
Parent { Parent { key: usize, original_parent: usize },
key: usize, Value { key: usize, original_value: Option<V> },
original_parent: usize, Rank { key: usize, original_rank: u32 },
}, Marker { generation: u32 },
Value {
key: usize,
original_value: Option<V>,
},
Rank {
key: usize,
original_rank: u32,
},
Marker {
generation: u32,
}
} }
impl<V> Default for UnificationTable<V> { impl<V> Default for UnificationTable<V> {
@ -41,12 +30,12 @@ impl<V> Default for UnificationTable<V> {
impl<V> UnificationTable<V> { impl<V> UnificationTable<V> {
pub fn new() -> UnificationTable<V> { pub fn new() -> UnificationTable<V> {
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 } UnificationTable {
} parents: Vec::new(),
ranks: Vec::new(),
fn log_action(&mut self, action: Action<V>) { values: Vec::new(),
if !self.log.is_empty() { log: Vec::new(),
self.log.push(action); generation: 0,
} }
} }
@ -67,10 +56,10 @@ impl<V> UnificationTable<V> {
if self.ranks[a] < self.ranks[b] { if self.ranks[a] < self.ranks[b] {
std::mem::swap(&mut a, &mut b); std::mem::swap(&mut a, &mut b);
} }
self.log_action(Action::Parent { key: b, original_parent: self.parents[b] }); self.log.push(Action::Parent { key: b, original_parent: self.parents[b] });
self.parents[b] = a; self.parents[b] = a;
if self.ranks[a] == self.ranks[b] { if self.ranks[a] == self.ranks[b] {
self.log_action(Action::Rank { key: a, original_rank: self.ranks[a] }); self.log.push(Action::Rank { key: a, original_rank: self.ranks[a] });
self.ranks[a] += 1; self.ranks[a] += 1;
} }
} }
@ -94,7 +83,7 @@ impl<V> UnificationTable<V> {
pub fn set_value(&mut self, a: UnificationKey, v: V) { pub fn set_value(&mut self, a: UnificationKey, v: V) {
let index = self.find(a); let index = self.find(a);
let original_value = self.values[index].replace(v); let original_value = self.values[index].replace(v);
self.log_action(Action::Value { key: index, original_value }); self.log.push(Action::Value { key: index, original_value });
} }
pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool { pub fn unioned(&mut self, a: UnificationKey, b: UnificationKey) -> bool {
@ -112,7 +101,7 @@ impl<V> UnificationTable<V> {
// a = parent.parent // a = parent.parent
let a = self.parents[parent]; let a = self.parents[parent];
// root.parent = parent.parent // root.parent = parent.parent
self.log_action(Action::Parent { key: root, original_parent: self.parents[root] }); self.log.push(Action::Parent { key: root, original_parent: self.parents[root] });
self.parents[root] = a; self.parents[root] = a;
root = parent; root = parent;
// parent = root.parent // parent = root.parent
@ -131,7 +120,10 @@ impl<V> UnificationTable<V> {
pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) { pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot; let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot restoration error"); assert!(self.log.len() >= log_len, "snapshot restoration error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot restoration error"); assert!(
matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation),
"snapshot restoration error"
);
for action in self.log.drain(log_len - 1..).rev() { for action in self.log.drain(log_len - 1..).rev() {
match action { match action {
Action::Parent { key, original_parent } => { Action::Parent { key, original_parent } => {
@ -151,7 +143,10 @@ impl<V> UnificationTable<V> {
pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) { pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot; let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot discard error"); assert!(self.log.len() >= log_len, "snapshot discard error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot discard error"); assert!(
matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation),
"snapshot discard error"
);
self.log.clear(); self.log.clear();
} }
} }
@ -165,11 +160,23 @@ where
.enumerate() .enumerate()
.map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None }) .map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None })
.collect(); .collect();
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values, log: Vec::new(), generation: 0 } UnificationTable {
parents: self.parents.clone(),
ranks: self.ranks.clone(),
values,
log: Vec::new(),
generation: 0,
}
} }
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> { pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect(); let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect();
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values, log: Vec::new(), generation: 0 } UnificationTable {
parents: table.parents.clone(),
ranks: table.ranks.clone(),
values,
log: Vec::new(),
generation: 0,
}
} }
} }

8
nac3ld/Cargo.toml Normal file
View File

@ -0,0 +1,8 @@
[package]
name = "nac3ld"
version = "0.1.0"
authors = ["M-Labs"]
edition = "2021"
[dependencies]
byteorder = { version = "1.5", default-features = false }

509
nac3ld/src/dwarf.rs Normal file
View File

@ -0,0 +1,509 @@
#![allow(non_camel_case_types, non_upper_case_globals)]
use std::mem;
use byteorder::{ByteOrder, LittleEndian};
pub const DW_EH_PE_omit: u8 = 0xFF;
pub const DW_EH_PE_absptr: u8 = 0x00;
pub const DW_EH_PE_uleb128: u8 = 0x01;
pub const DW_EH_PE_udata2: u8 = 0x02;
pub const DW_EH_PE_udata4: u8 = 0x03;
pub const DW_EH_PE_udata8: u8 = 0x04;
pub const DW_EH_PE_sleb128: u8 = 0x09;
pub const DW_EH_PE_sdata2: u8 = 0x0A;
pub const DW_EH_PE_sdata4: u8 = 0x0B;
pub const DW_EH_PE_sdata8: u8 = 0x0C;
pub const DW_EH_PE_pcrel: u8 = 0x10;
pub const DW_EH_PE_textrel: u8 = 0x20;
pub const DW_EH_PE_datarel: u8 = 0x30;
pub const DW_EH_PE_funcrel: u8 = 0x40;
pub const DW_EH_PE_aligned: u8 = 0x50;
pub const DW_EH_PE_indirect: u8 = 0x80;
pub struct DwarfReader<'a> {
pub slice: &'a [u8],
pub virt_addr: u32,
base_slice: &'a [u8],
base_virt_addr: u32,
}
impl<'a> DwarfReader<'a> {
pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader {
DwarfReader { slice, virt_addr, base_slice: slice, base_virt_addr: virt_addr }
}
/// Creates a new instance from another instance of [DwarfReader], optionally removing any
/// offsets previously applied to the other instance.
pub fn from_reader(other: &DwarfReader<'a>, reset_offset: bool) -> DwarfReader<'a> {
if reset_offset {
DwarfReader::new(other.base_slice, other.base_virt_addr)
} else {
DwarfReader::new(other.slice, other.virt_addr)
}
}
pub fn offset(&mut self, offset: u32) {
self.slice = &self.slice[offset as usize..];
self.virt_addr = self.virt_addr.wrapping_add(offset);
}
/// ULEB128 and SLEB128 encodings are defined in Section 7.6 - "Variable Length Data" of the
/// [DWARF-4 Manual](https://dwarfstd.org/doc/DWARF4.pdf).
pub fn read_uleb128(&mut self) -> u64 {
let mut shift: usize = 0;
let mut result: u64 = 0;
let mut byte: u8;
loop {
byte = self.read_u8();
result |= u64::from(byte & 0x7F) << shift;
shift += 7;
if byte & 0x80 == 0 {
break;
}
}
result
}
pub fn read_sleb128(&mut self) -> i64 {
let mut shift: u32 = 0;
let mut result: u64 = 0;
let mut byte: u8;
loop {
byte = self.read_u8();
result |= u64::from(byte & 0x7F) << shift;
shift += 7;
if byte & 0x80 == 0 {
break;
}
}
// sign-extend
if shift < u64::BITS && (byte & 0x40) != 0 {
result |= (!0u64) << shift;
}
result as i64
}
pub fn read_u8(&mut self) -> u8 {
let val = self.slice[0];
self.slice = &self.slice[1..];
val
}
}
macro_rules! impl_read_fn {
( $($type: ty, $byteorder_fn: ident);* ) => {
impl<'a> DwarfReader<'a> {
$(
pub fn $byteorder_fn(&mut self) -> $type {
let val = LittleEndian::$byteorder_fn(self.slice);
self.slice = &self.slice[mem::size_of::<$type>()..];
val
}
)*
}
}
}
impl_read_fn!(
u16, read_u16;
u32, read_u32;
u64, read_u64;
i16, read_i16;
i32, read_i32;
i64, read_i64
);
pub struct DwarfWriter<'a> {
pub slice: &'a mut [u8],
pub offset: usize,
}
impl<'a> DwarfWriter<'a> {
pub fn new(slice: &mut [u8]) -> DwarfWriter {
DwarfWriter { slice, offset: 0 }
}
pub fn write_u8(&mut self, data: u8) {
self.slice[self.offset] = data;
self.offset += 1;
}
pub fn write_u32(&mut self, data: u32) {
LittleEndian::write_u32(&mut self.slice[self.offset..], data);
self.offset += 4;
}
}
fn read_encoded_pointer(reader: &mut DwarfReader, encoding: u8) -> Result<usize, ()> {
if encoding == DW_EH_PE_omit {
return Err(());
}
// DW_EH_PE_aligned implies it's an absolute pointer value
// However, we are linking library for 32-bits architecture
// The size of variable should be 4 bytes instead
if encoding == DW_EH_PE_aligned {
let shifted_virt_addr = round_up(reader.virt_addr as usize, mem::size_of::<u32>())?;
let addr_inc = shifted_virt_addr - reader.virt_addr as usize;
reader.slice = &reader.slice[addr_inc..];
reader.virt_addr = shifted_virt_addr as u32;
return Ok(reader.read_u32() as usize);
}
match encoding & 0x0F {
DW_EH_PE_absptr | DW_EH_PE_udata4 => Ok(reader.read_u32() as usize),
DW_EH_PE_uleb128 => Ok(reader.read_uleb128() as usize),
DW_EH_PE_udata2 => Ok(reader.read_u16() as usize),
DW_EH_PE_udata8 => Ok(reader.read_u64() as usize),
DW_EH_PE_sleb128 => Ok(reader.read_sleb128() as usize),
DW_EH_PE_sdata2 => Ok(reader.read_i16() as usize),
DW_EH_PE_sdata4 => Ok(reader.read_i32() as usize),
DW_EH_PE_sdata8 => Ok(reader.read_i64() as usize),
_ => Err(()),
}
}
fn read_encoded_pointer_with_pc(reader: &mut DwarfReader, encoding: u8) -> Result<usize, ()> {
let entry_virt_addr = reader.virt_addr;
let mut result = read_encoded_pointer(reader, encoding)?;
// DW_EH_PE_aligned implies it's an absolute pointer value
if encoding == DW_EH_PE_aligned {
return Ok(result);
}
result = match encoding & 0x70 {
DW_EH_PE_pcrel => result.wrapping_add(entry_virt_addr as usize),
// .eh_frame normally would not have these kinds of relocations
// These would not be supported by a dedicated linker relocation schemes for RISC-V
DW_EH_PE_textrel | DW_EH_PE_datarel | DW_EH_PE_funcrel | DW_EH_PE_aligned => {
unimplemented!()
}
// Other values should be impossible
_ => unreachable!(),
};
if encoding & DW_EH_PE_indirect != 0 {
// There should not be a need for indirect addressing, as assembly code from
// the dynamic library should not be freely moved relative to the EH frame.
unreachable!()
}
Ok(result)
}
#[inline]
fn round_up(unrounded: usize, align: usize) -> Result<usize, ()> {
if align.is_power_of_two() {
Ok((unrounded + align - 1) & !(align - 1))
} else {
Err(())
}
}
/// Minimalistic structure to store everything needed for parsing FDEs to synthesize `.eh_frame_hdr`
/// section.
///
/// Refer to [The Linux Standard Base Core Specification, Generic Part](https://refspecs.linuxfoundation.org/LSB_5.0.0/LSB-Core-generic/LSB-Core-generic/ehframechpt.html)
/// for more information.
pub struct EH_Frame<'a> {
reader: DwarfReader<'a>,
}
impl<'a> EH_Frame<'a> {
/// Creates an [EH_Frame] using the bytes in the `.eh_frame` section and its address in the ELF
/// file.
pub fn new(eh_frame_slice: &[u8], eh_frame_addr: u32) -> EH_Frame {
EH_Frame { reader: DwarfReader::new(eh_frame_slice, eh_frame_addr) }
}
/// Returns an [Iterator] over all Call Frame Information (CFI) records.
pub fn cfi_records(&self) -> CFI_Records<'a> {
let reader = DwarfReader::from_reader(&self.reader, true);
let len = reader.slice.len();
CFI_Records { reader, available: len }
}
}
/// A single Call Frame Information (CFI) record.
///
/// From the [specification](https://refspecs.linuxfoundation.org/LSB_5.0.0/LSB-Core-generic/LSB-Core-generic/ehframechpt.html):
///
/// > Each CFI record contains a Common Information Entry (CIE) record followed by 1 or more Frame
/// > Description Entry (FDE) records.
pub struct CFI_Record<'a> {
// It refers to the augmentation data that corresponds to 'R' in the augmentation string
fde_pointer_encoding: u8,
fde_reader: DwarfReader<'a>,
}
impl<'a> CFI_Record<'a> {
pub fn from_reader(cie_reader: &mut DwarfReader<'a>) -> Result<CFI_Record<'a>, ()> {
let length = cie_reader.read_u32();
let fde_reader = match length {
// eh_frame with 0 lengths means the CIE is terminated
0 => panic!("Cannot create an EH_Frame from a termination CIE"),
// length == u32::MAX means that the length is only representable with 64 bits,
// which does not make sense in a system with 32-bit address.
0xFFFF_FFFF => unimplemented!(),
_ => {
let mut fde_reader = DwarfReader::from_reader(cie_reader, false);
fde_reader.offset(length);
fde_reader
}
};
// Routine check on the .eh_frame well-formness, in terms of CIE ID & Version args.
let cie_ptr = cie_reader.read_u32();
assert_eq!(cie_ptr, 0);
assert_eq!(cie_reader.read_u8(), 1);
// Parse augmentation string
// The first character must be 'z', there is no way to proceed otherwise
assert_eq!(cie_reader.read_u8(), b'z');
// Establish a pointer that skips ahead of the string
// Skip code/data alignment factors & return address register along the way as well
// We only tackle the case where 'z' and 'R' are part of the augmentation string, otherwise
// we cannot get the addresses to make .eh_frame_hdr
let mut aug_data_reader = DwarfReader::from_reader(cie_reader, false);
let mut aug_str_len = 0;
loop {
if aug_data_reader.read_u8() == b'\0' {
break;
}
aug_str_len += 1;
}
if aug_str_len == 0 {
unimplemented!();
}
aug_data_reader.read_uleb128(); // Code alignment factor
aug_data_reader.read_sleb128(); // Data alignment factor
aug_data_reader.read_uleb128(); // Return address register
aug_data_reader.read_uleb128(); // Augmentation data length
let mut fde_pointer_encoding = DW_EH_PE_omit;
for _ in 0..aug_str_len {
match cie_reader.read_u8() {
b'L' => {
aug_data_reader.read_u8();
}
b'P' => {
let encoding = aug_data_reader.read_u8();
read_encoded_pointer(&mut aug_data_reader, encoding)?;
}
b'R' => {
fde_pointer_encoding = aug_data_reader.read_u8();
}
// Other characters are not supported
_ => unimplemented!(),
}
}
assert_ne!(fde_pointer_encoding, DW_EH_PE_omit);
Ok(CFI_Record { fde_pointer_encoding, fde_reader })
}
/// Returns a [DwarfReader] initialized to the first Frame Description Entry (FDE) of this CFI
/// record.
pub fn get_fde_reader(&self) -> DwarfReader<'a> {
DwarfReader::from_reader(&self.fde_reader, true)
}
/// Returns an [Iterator] over all Frame Description Entries (FDEs).
pub fn fde_records(&self) -> FDE_Records<'a> {
let reader = self.get_fde_reader();
let len = reader.slice.len();
FDE_Records { pointer_encoding: self.fde_pointer_encoding, reader, available: len }
}
}
/// [Iterator] over Call Frame Information (CFI) records in an
/// [Exception Handling (EH) frame][EH_Frame].
pub struct CFI_Records<'a> {
reader: DwarfReader<'a>,
available: usize,
}
impl<'a> Iterator for CFI_Records<'a> {
type Item = CFI_Record<'a>;
fn next(&mut self) -> Option<Self::Item> {
loop {
if self.available == 0 {
return None;
}
let mut this_reader = DwarfReader::from_reader(&self.reader, false);
// Remove the length of the header and the content from the counter
let length = self.reader.read_u32();
let length = match length {
// eh_frame with 0-length means the CIE is terminated
0 => return None,
0xFFFF_FFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
other => other,
} as usize;
// Remove the length of the header and the content from the counter
self.available -= length + mem::size_of::<u32>();
let mut next_reader = DwarfReader::from_reader(&self.reader, false);
next_reader.offset(length as u32);
let cie_ptr = self.reader.read_u32();
self.reader = next_reader;
// Skip this record if it is a FDE
if cie_ptr == 0 {
// Rewind back to the start of the CFI Record
return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap());
}
}
}
}
/// [Iterator] over Frame Description Entries (FDEs) in an
/// [Exception Handling (EH) frame][EH_Frame].
pub struct FDE_Records<'a> {
pointer_encoding: u8,
reader: DwarfReader<'a>,
available: usize,
}
impl<'a> Iterator for FDE_Records<'a> {
type Item = (u32, u32);
fn next(&mut self) -> Option<Self::Item> {
// Parse each FDE to obtain the starting address that the FDE applies to
// Send the FDE offset and the mentioned address to a callback that write up the
// .eh_frame_hdr section
if self.available == 0 {
return None;
}
// Remove the length of the header and the content from the counter
let length = match self.reader.read_u32() {
// eh_frame with 0-length means the CIE is terminated
0 => return None,
0xFFFF_FFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
other => other,
} as usize;
// Remove the length of the header and the content from the counter
self.available -= length + mem::size_of::<u32>();
let mut next_fde_reader = DwarfReader::from_reader(&self.reader, false);
next_fde_reader.offset(length as u32);
let cie_ptr = self.reader.read_u32();
let next_val = if cie_ptr != 0 {
let pc_begin = read_encoded_pointer_with_pc(&mut self.reader, self.pointer_encoding)
.expect("Failed to read PC Begin");
Some((pc_begin as u32, self.reader.virt_addr))
} else {
None
};
self.reader = next_fde_reader;
next_val
}
}
pub struct EH_Frame_Hdr<'a> {
fde_writer: DwarfWriter<'a>,
eh_frame_hdr_addr: u32,
fdes: Vec<(u32, u32)>,
}
impl<'a> EH_Frame_Hdr<'a> {
/// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory.
///
/// Load address is not known at this point.
pub fn new(
eh_frame_hdr_slice: &mut [u8],
eh_frame_hdr_addr: u32,
eh_frame_addr: u32,
) -> EH_Frame_Hdr {
let mut writer = DwarfWriter::new(eh_frame_hdr_slice);
writer.write_u8(1); // version
writer.write_u8(0x1B); // eh_frame_ptr_enc - PC-relative 4-byte signed value
writer.write_u8(0x03); // fde_count_enc - 4-byte unsigned value
writer.write_u8(0x3B); // table_enc - .eh_frame_hdr section-relative 4-byte signed value
let eh_frame_offset = eh_frame_addr.wrapping_sub(
eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::<u8>() as u32) * 4),
);
writer.write_u32(eh_frame_offset); // eh_frame_ptr
writer.write_u32(0); // `fde_count`, will be written in finalize_fde
EH_Frame_Hdr { fde_writer: writer, eh_frame_hdr_addr, fdes: Vec::new() }
}
/// The offset of the `fde_count` value relative to the start of the `.eh_frame_hdr` section in
/// bytes.
fn fde_count_offset() -> usize {
8
}
pub fn add_fde(&mut self, init_loc: u32, addr: u32) {
self.fdes.push((
init_loc.wrapping_sub(self.eh_frame_hdr_addr),
addr.wrapping_sub(self.eh_frame_hdr_addr),
));
}
pub fn finalize_fde(mut self) {
self.fdes
.sort_by(|(left_init_loc, _), (right_init_loc, _)| left_init_loc.cmp(right_init_loc));
for (init_loc, addr) in &self.fdes {
self.fde_writer.write_u32(*init_loc);
self.fde_writer.write_u32(*addr);
}
LittleEndian::write_u32(
&mut self.fde_writer.slice[Self::fde_count_offset()..],
self.fdes.len() as u32,
);
}
pub fn size_from_eh_frame(eh_frame: &[u8]) -> usize {
// The virtual address of the EH frame does not matter in this case
// Calculation of size does not involve modifying any headers
let mut reader = DwarfReader::new(eh_frame, 0);
let mut fde_count = 0;
while !reader.slice.is_empty() {
// The original length field should be able to hold the entire value.
// The device memory space is limited to 32-bits addresses anyway.
let entry_length = reader.read_u32();
if entry_length == 0 || entry_length == 0xFFFF_FFFF {
unimplemented!()
}
// This slot stores the CIE ID (for CIE)/CIE Pointer (for FDE).
// This value must be non-zero for FDEs.
let cie_ptr = reader.read_u32();
if cie_ptr != 0 {
fde_count += 1;
}
reader.offset(entry_length - mem::size_of::<u32>() as u32);
}
12 + fde_count * 8
}
}

2893
nac3ld/src/elf.rs Normal file

File diff suppressed because it is too large Load Diff

1502
nac3ld/src/lib.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@ -5,20 +5,20 @@ description = "Parser for python code."
authors = [ "RustPython Team", "M-Labs" ] authors = [ "RustPython Team", "M-Labs" ]
build = "build.rs" build = "build.rs"
license = "MIT" license = "MIT"
edition = "2018" edition = "2021"
[build-dependencies] [build-dependencies]
lalrpop = "0.19.6" lalrpop = "0.22"
[dependencies] [dependencies]
nac3ast = { path = "../nac3ast" } nac3ast = { path = "../nac3ast" }
lalrpop-util = "0.19.6" lalrpop-util = "0.22"
log = "0.4.1" log = "0.4"
unic-emoji-char = "0.9" unic-emoji-char = "0.9"
unic-ucd-ident = "0.9" unic-ucd-ident = "0.9"
unicode_names2 = "0.4" unicode_names2 = "1.3"
phf = { version = "0.9", features = ["macros"] } phf = { version = "0.11", features = ["macros"] }
ahash = "0.7.2" ahash = "0.8"
[dev-dependencies] [dev-dependencies]
insta = "=1.11.0" insta = "=1.11.0"

View File

@ -1,15 +1,17 @@
use crate::{
ast::{Ident, Location},
error::*,
token::Tok,
};
use lalrpop_util::ParseError; use lalrpop_util::ParseError;
use nac3ast::*; use nac3ast::*;
use crate::ast::Ident;
use crate::ast::Location;
use crate::token::Tok;
use crate::error::*;
pub fn make_config_comment( pub fn make_config_comment(
com_loc: Location, com_loc: Location,
stmt_loc: Location, stmt_loc: Location,
nac3com_above: Vec<(Ident, Tok)>, nac3com_above: Vec<(Ident, Tok)>,
nac3com_end: Option<Ident> nac3com_end: Option<Ident>,
) -> Result<Vec<Ident>, ParseError<Location, Tok, LexicalError>> { ) -> Result<Vec<Ident>, ParseError<Location, Tok, LexicalError>> {
if com_loc.column() != stmt_loc.column() && !nac3com_above.is_empty() { if com_loc.column() != stmt_loc.column() && !nac3com_above.is_empty() {
return Err(ParseError::User { return Err(ParseError::User {
@ -17,24 +19,25 @@ pub fn make_config_comment(
location: com_loc, location: com_loc,
error: LexicalErrorType::OtherError( error: LexicalErrorType::OtherError(
format!( format!(
"config comment at top must have the same indentation with what it applies (comment at {}, statement at {})", "config comment at top must have the same indentation with what it applies (comment at {com_loc}, statement at {stmt_loc})",
com_loc,
stmt_loc,
) )
) )
} }
}) });
}; };
Ok( Ok(nac3com_above
nac3com_above .into_iter()
.into_iter() .map(|(com, _)| com)
.map(|(com, _)| com) .chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter()))
.chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter())) .collect())
.collect()
)
} }
pub fn handle_small_stmt<U>(stmts: &mut [Stmt<U>], nac3com_above: Vec<(Ident, Tok)>, nac3com_end: Option<Ident>, com_above_loc: Location) -> Result<(), ParseError<Location, Tok, LexicalError>> { pub fn handle_small_stmt<U>(
stmts: &mut [Stmt<U>],
nac3com_above: Vec<(Ident, Tok)>,
nac3com_end: Option<Ident>,
com_above_loc: Location,
) -> Result<(), ParseError<Location, Tok, LexicalError>> {
if com_above_loc.column() != stmts[0].location.column() && !nac3com_above.is_empty() { if com_above_loc.column() != stmts[0].location.column() && !nac3com_above.is_empty() {
return Err(ParseError::User { return Err(ParseError::User {
error: LexicalError { error: LexicalError {
@ -47,17 +50,12 @@ pub fn handle_small_stmt<U>(stmts: &mut [Stmt<U>], nac3com_above: Vec<(Ident, To
) )
) )
} }
}) });
} }
apply_config_comments( apply_config_comments(&mut stmts[0], nac3com_above.into_iter().map(|(com, _)| com).collect());
&mut stmts[0],
nac3com_above
.into_iter()
.map(|(com, _)| com).collect()
);
apply_config_comments( apply_config_comments(
stmts.last_mut().unwrap(), stmts.last_mut().unwrap(),
nac3com_end.map_or_else(Vec::new, |com| vec![com]) nac3com_end.map_or_else(Vec::new, |com| vec![com]),
); );
Ok(()) Ok(())
} }
@ -80,6 +78,8 @@ fn apply_config_comments<U>(stmt: &mut Stmt<U>, comments: Vec<Ident>) {
| StmtKind::Nonlocal { config_comment, .. } | StmtKind::Nonlocal { config_comment, .. }
| StmtKind::Assert { config_comment, .. } => config_comment.extend(comments), | StmtKind::Assert { config_comment, .. } => config_comment.extend(comments),
_ => { unreachable!("only small statements should call this function") } _ => {
unreachable!("only small statements should call this function")
}
} }
} }

View File

@ -1,13 +1,12 @@
//! Define internal parse error types //! Define internal parse error types
//! The goal is to provide a matching and a safe error API, maksing errors from LALR //! The goal is to provide a matching and a safe error API, maksing errors from LALR
use lalrpop_util::ParseError as LalrpopError;
use crate::ast::Location;
use crate::token::Tok;
use std::error::Error; use std::error::Error;
use std::fmt; use std::fmt;
use lalrpop_util::ParseError as LalrpopError;
use crate::{ast::Location, token::Tok};
/// Represents an error during lexical scanning. /// Represents an error during lexical scanning.
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub struct LexicalError { pub struct LexicalError {
@ -37,7 +36,7 @@ impl fmt::Display for LexicalErrorType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
LexicalErrorType::StringError => write!(f, "Got unexpected string"), LexicalErrorType::StringError => write!(f, "Got unexpected string"),
LexicalErrorType::FStringError(error) => write!(f, "Got error in f-string: {}", error), LexicalErrorType::FStringError(error) => write!(f, "Got error in f-string: {error}"),
LexicalErrorType::UnicodeError => write!(f, "Got unexpected unicode"), LexicalErrorType::UnicodeError => write!(f, "Got unexpected unicode"),
LexicalErrorType::NestingError => write!(f, "Got unexpected nesting"), LexicalErrorType::NestingError => write!(f, "Got unexpected nesting"),
LexicalErrorType::IndentationError => { LexicalErrorType::IndentationError => {
@ -59,13 +58,13 @@ impl fmt::Display for LexicalErrorType {
write!(f, "positional argument follows keyword argument") write!(f, "positional argument follows keyword argument")
} }
LexicalErrorType::UnrecognizedToken { tok } => { LexicalErrorType::UnrecognizedToken { tok } => {
write!(f, "Got unexpected token {}", tok) write!(f, "Got unexpected token {tok}")
} }
LexicalErrorType::LineContinuationError => { LexicalErrorType::LineContinuationError => {
write!(f, "unexpected character after line continuation character") write!(f, "unexpected character after line continuation character")
} }
LexicalErrorType::Eof => write!(f, "unexpected EOF while parsing"), LexicalErrorType::Eof => write!(f, "unexpected EOF while parsing"),
LexicalErrorType::OtherError(msg) => write!(f, "{}", msg), LexicalErrorType::OtherError(msg) => write!(f, "{msg}"),
} }
} }
} }
@ -96,7 +95,7 @@ impl fmt::Display for FStringErrorType {
FStringErrorType::UnopenedRbrace => write!(f, "Unopened '}}'"), FStringErrorType::UnopenedRbrace => write!(f, "Unopened '}}'"),
FStringErrorType::ExpectedRbrace => write!(f, "Expected '}}' after conversion flag."), FStringErrorType::ExpectedRbrace => write!(f, "Expected '}}' after conversion flag."),
FStringErrorType::InvalidExpression(error) => { FStringErrorType::InvalidExpression(error) => {
write!(f, "Invalid expression: {}", error) write!(f, "Invalid expression: {error}")
} }
FStringErrorType::InvalidConversionFlag => write!(f, "Invalid conversion flag"), FStringErrorType::InvalidConversionFlag => write!(f, "Invalid conversion flag"),
FStringErrorType::EmptyExpression => write!(f, "Empty expression"), FStringErrorType::EmptyExpression => write!(f, "Empty expression"),
@ -144,36 +143,27 @@ pub enum ParseErrorType {
impl From<LalrpopError<Location, Tok, LexicalError>> for ParseError { impl From<LalrpopError<Location, Tok, LexicalError>> for ParseError {
fn from(err: LalrpopError<Location, Tok, LexicalError>) -> Self { fn from(err: LalrpopError<Location, Tok, LexicalError>) -> Self {
match err { match err {
// TODO: Are there cases where this isn't an EOF? LalrpopError::ExtraToken { token } => {
LalrpopError::InvalidToken { location } => ParseError { ParseError { error: ParseErrorType::ExtraToken(token.1), location: token.0 }
error: ParseErrorType::Eof, }
location, LalrpopError::User { error } => {
}, ParseError { error: ParseErrorType::Lexical(error.error), location: error.location }
LalrpopError::ExtraToken { token } => ParseError { }
error: ParseErrorType::ExtraToken(token.1),
location: token.0,
},
LalrpopError::User { error } => ParseError {
error: ParseErrorType::Lexical(error.error),
location: error.location,
},
LalrpopError::UnrecognizedToken { token, expected } => { LalrpopError::UnrecognizedToken { token, expected } => {
// Hacky, but it's how CPython does it. See PyParser_AddToken, // Hacky, but it's how CPython does it. See PyParser_AddToken,
// in particular "Only one possible expected token" comment. // in particular "Only one possible expected token" comment.
let expected = if expected.len() == 1 { let expected = if expected.len() == 1 { Some(expected[0].clone()) } else { None };
Some(expected[0].clone())
} else {
None
};
ParseError { ParseError {
error: ParseErrorType::UnrecognizedToken(token.1, expected), error: ParseErrorType::UnrecognizedToken(token.1, expected),
location: token.0, location: token.0,
} }
} }
LalrpopError::UnrecognizedEOF { location, .. } => ParseError {
error: ParseErrorType::Eof, LalrpopError::UnrecognizedEof { location, .. }
location, // TODO: Are there cases where this isn't an EOF?
}, | LalrpopError::InvalidToken { location } => {
ParseError { error: ParseErrorType::Eof, location }
}
} }
} }
} }
@ -188,7 +178,7 @@ impl fmt::Display for ParseErrorType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
ParseErrorType::Eof => write!(f, "Got unexpected EOF"), ParseErrorType::Eof => write!(f, "Got unexpected EOF"),
ParseErrorType::ExtraToken(ref tok) => write!(f, "Got extraneous token: {:?}", tok), ParseErrorType::ExtraToken(ref tok) => write!(f, "Got extraneous token: {tok:?}"),
ParseErrorType::InvalidToken => write!(f, "Got invalid token"), ParseErrorType::InvalidToken => write!(f, "Got invalid token"),
ParseErrorType::UnrecognizedToken(ref tok, ref expected) => { ParseErrorType::UnrecognizedToken(ref tok, ref expected) => {
if *tok == Tok::Indent { if *tok == Tok::Indent {
@ -196,10 +186,10 @@ impl fmt::Display for ParseErrorType {
} else if expected.as_deref() == Some("Indent") { } else if expected.as_deref() == Some("Indent") {
write!(f, "expected an indented block") write!(f, "expected an indented block")
} else { } else {
write!(f, "Got unexpected token {}", tok) write!(f, "Got unexpected token {tok}")
} }
} }
ParseErrorType::Lexical(ref error) => write!(f, "{}", error), ParseErrorType::Lexical(ref error) => write!(f, "{error}"),
} }
} }
} }
@ -207,6 +197,7 @@ impl fmt::Display for ParseErrorType {
impl Error for ParseErrorType {} impl Error for ParseErrorType {}
impl ParseErrorType { impl ParseErrorType {
#[must_use]
pub fn is_indentation_error(&self) -> bool { pub fn is_indentation_error(&self) -> bool {
match self { match self {
ParseErrorType::Lexical(LexicalErrorType::IndentationError) => true, ParseErrorType::Lexical(LexicalErrorType::IndentationError) => true,
@ -216,11 +207,11 @@ impl ParseErrorType {
_ => false, _ => false,
} }
} }
#[must_use]
pub fn is_tab_error(&self) -> bool { pub fn is_tab_error(&self) -> bool {
matches!( matches!(
self, self,
ParseErrorType::Lexical(LexicalErrorType::TabError) ParseErrorType::Lexical(LexicalErrorType::TabError | LexicalErrorType::TabsAfterSpaces)
| ParseErrorType::Lexical(LexicalErrorType::TabsAfterSpaces)
) )
} }
} }

View File

@ -1,12 +1,11 @@
use std::iter; use std::{iter, mem, str};
use std::mem;
use std::str;
use crate::ast::{Constant, ConversionFlag, Expr, ExprKind, Location};
use crate::error::{FStringError, FStringErrorType, ParseError};
use crate::parser::parse_expression;
use self::FStringErrorType::*; use self::FStringErrorType::*;
use crate::{
ast::{Constant, ConversionFlag, Expr, ExprKind, Location},
error::{FStringError, FStringErrorType, ParseError},
parser::parse_expression,
};
struct FStringParser<'a> { struct FStringParser<'a> {
chars: iter::Peekable<str::Chars<'a>>, chars: iter::Peekable<str::Chars<'a>>,
@ -15,10 +14,7 @@ struct FStringParser<'a> {
impl<'a> FStringParser<'a> { impl<'a> FStringParser<'a> {
fn new(source: &'a str, str_location: Location) -> Self { fn new(source: &'a str, str_location: Location) -> Self {
Self { Self { chars: source.chars().peekable(), str_location }
chars: source.chars().peekable(),
str_location,
}
} }
#[inline] #[inline]
@ -133,10 +129,10 @@ impl<'a> FStringParser<'a> {
) )
} else { } else {
Box::new(self.expr(ExprKind::Constant { Box::new(self.expr(ExprKind::Constant {
value: spec_expression.to_owned().into(), value: spec_expression.clone().into(),
kind: None, kind: None,
})) }))
}) });
} }
'(' | '{' | '[' => { '(' | '{' | '[' => {
expression.push(ch); expression.push(ch);
@ -251,17 +247,11 @@ impl<'a> FStringParser<'a> {
} }
if !content.is_empty() { if !content.is_empty() {
values.push(self.expr(ExprKind::Constant { values.push(self.expr(ExprKind::Constant { value: content.into(), kind: None }));
value: content.into(),
kind: None,
}))
} }
let s = match values.len() { let s = match values.len() {
0 => self.expr(ExprKind::Constant { 0 => self.expr(ExprKind::Constant { value: String::new().into(), kind: None }),
value: String::new().into(),
kind: None,
}),
1 => values.into_iter().next().unwrap(), 1 => values.into_iter().next().unwrap(),
_ => self.expr(ExprKind::JoinedStr { values }), _ => self.expr(ExprKind::JoinedStr { values }),
}; };
@ -270,16 +260,14 @@ impl<'a> FStringParser<'a> {
} }
fn parse_fstring_expr(source: &str) -> Result<Expr, ParseError> { fn parse_fstring_expr(source: &str) -> Result<Expr, ParseError> {
let fstring_body = format!("({})", source); let fstring_body = format!("({source})");
parse_expression(&fstring_body) parse_expression(&fstring_body)
} }
/// Parse an fstring from a string, located at a certain position in the sourcecode. /// Parse an fstring from a string, located at a certain position in the sourcecode.
/// In case of errors, we will get the location and the error returned. /// In case of errors, we will get the location and the error returned.
pub fn parse_located_fstring(source: &str, location: Location) -> Result<Expr, FStringError> { pub fn parse_located_fstring(source: &str, location: Location) -> Result<Expr, FStringError> {
FStringParser::new(source, location) FStringParser::new(source, location).parse().map_err(|error| FStringError { error, location })
.parse()
.map_err(|error| FStringError { error, location })
} }
#[cfg(test)] #[cfg(test)]
@ -293,7 +281,7 @@ mod tests {
#[test] #[test]
fn test_parse_fstring() { fn test_parse_fstring() {
let source = "{a}{ b }{{foo}}"; let source = "{a}{ b }{{foo}}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -301,7 +289,7 @@ mod tests {
#[test] #[test]
fn test_parse_fstring_nested_spec() { fn test_parse_fstring_nested_spec() {
let source = "{foo:{spec}}"; let source = "{foo:{spec}}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -309,7 +297,7 @@ mod tests {
#[test] #[test]
fn test_parse_fstring_not_nested_spec() { fn test_parse_fstring_not_nested_spec() {
let source = "{foo:spec}"; let source = "{foo:spec}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -322,7 +310,7 @@ mod tests {
#[test] #[test]
fn test_fstring_parse_selfdocumenting_base() { fn test_fstring_parse_selfdocumenting_base() {
let src = "{user=}"; let src = "{user=}";
let parse_ast = parse_fstring(&src).unwrap(); let parse_ast = parse_fstring(src).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -330,7 +318,7 @@ mod tests {
#[test] #[test]
fn test_fstring_parse_selfdocumenting_base_more() { fn test_fstring_parse_selfdocumenting_base_more() {
let src = "mix {user=} with text and {second=}"; let src = "mix {user=} with text and {second=}";
let parse_ast = parse_fstring(&src).unwrap(); let parse_ast = parse_fstring(src).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -338,7 +326,7 @@ mod tests {
#[test] #[test]
fn test_fstring_parse_selfdocumenting_format() { fn test_fstring_parse_selfdocumenting_format() {
let src = "{user=:>10}"; let src = "{user=:>10}";
let parse_ast = parse_fstring(&src).unwrap(); let parse_ast = parse_fstring(src).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -371,35 +359,35 @@ mod tests {
#[test] #[test]
fn test_parse_fstring_not_equals() { fn test_parse_fstring_not_equals() {
let source = "{1 != 2}"; let source = "{1 != 2}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_fstring_equals() { fn test_parse_fstring_equals() {
let source = "{42 == 42}"; let source = "{42 == 42}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_fstring_selfdoc_prec_space() { fn test_parse_fstring_selfdoc_prec_space() {
let source = "{x =}"; let source = "{x =}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_fstring_selfdoc_trailing_space() { fn test_parse_fstring_selfdoc_trailing_space() {
let source = "{x= }"; let source = "{x= }";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_fstring_yield_expr() { fn test_parse_fstring_yield_expr() {
let source = "{yield}"; let source = "{yield}";
let parse_ast = parse_fstring(&source).unwrap(); let parse_ast = parse_fstring(source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
} }

View File

@ -1,8 +1,11 @@
use ahash::RandomState;
use std::collections::HashSet; use std::collections::HashSet;
use crate::ast; use ahash::RandomState;
use crate::error::{LexicalError, LexicalErrorType};
use crate::{
ast,
error::{LexicalError, LexicalErrorType},
};
pub struct ArgumentList { pub struct ArgumentList {
pub args: Vec<ast::Expr>, pub args: Vec<ast::Expr>,
@ -54,38 +57,32 @@ pub fn parse_args(func_args: Vec<FunctionArgument>) -> Result<ArgumentList, Lexi
let mut keyword_names = HashSet::with_capacity_and_hasher(func_args.len(), RandomState::new()); let mut keyword_names = HashSet::with_capacity_and_hasher(func_args.len(), RandomState::new());
for (name, value) in func_args { for (name, value) in func_args {
match name { if let Some((location, name)) = name {
Some((location, name)) => { if let Some(keyword_name) = &name {
if let Some(keyword_name) = &name { if keyword_names.contains(keyword_name) {
if keyword_names.contains(keyword_name) {
return Err(LexicalError {
error: LexicalErrorType::DuplicateKeywordArgumentError,
location,
});
}
keyword_names.insert(keyword_name.clone());
}
keywords.push(ast::Keyword::new(
location,
ast::KeywordData {
arg: name.map(|name| name.into()),
value: Box::new(value),
},
));
}
None => {
// Allow starred args after keyword arguments.
if !keywords.is_empty() && !is_starred(&value) {
return Err(LexicalError { return Err(LexicalError {
error: LexicalErrorType::PositionalArgumentError, error: LexicalErrorType::DuplicateKeywordArgumentError,
location: value.location, location,
}); });
} }
args.push(value); keyword_names.insert(keyword_name.clone());
} }
keywords.push(ast::Keyword::new(
location,
ast::KeywordData { arg: name.map(String::into), value: Box::new(value) },
));
} else {
// Allow starred args after keyword arguments.
if !keywords.is_empty() && !is_starred(&value) {
return Err(LexicalError {
error: LexicalErrorType::PositionalArgumentError,
location: value.location,
});
}
args.push(value);
} }
} }
Ok(ArgumentList { args, keywords }) Ok(ArgumentList { args, keywords })

View File

@ -1,17 +1,17 @@
//! This module takes care of lexing python source text. //! This module takes care of lexing python source text.
//! //!
//! This means source code is translated into separate tokens. //! This means source code is translated into separate tokens.
use std::{char, cmp::Ordering, num::IntErrorKind, str::FromStr};
pub use super::token::Tok;
use crate::ast::{Location, FileName};
use crate::error::{LexicalError, LexicalErrorType};
use std::char;
use std::cmp::Ordering;
use std::str::FromStr;
use std::num::IntErrorKind;
use unic_emoji_char::is_emoji_presentation; use unic_emoji_char::is_emoji_presentation;
use unic_ucd_ident::{is_xid_continue, is_xid_start}; use unic_ucd_ident::{is_xid_continue, is_xid_start};
pub use super::token::Tok;
use crate::{
ast::{FileName, Location},
error::{LexicalError, LexicalErrorType},
};
#[derive(Clone, Copy, PartialEq, Debug, Default)] #[derive(Clone, Copy, PartialEq, Debug, Default)]
struct IndentationLevel { struct IndentationLevel {
tabs: usize, tabs: usize,
@ -32,20 +32,14 @@ impl IndentationLevel {
if self.spaces <= other.spaces { if self.spaces <= other.spaces {
Ok(Ordering::Less) Ok(Ordering::Less)
} else { } else {
Err(LexicalError { Err(LexicalError { location, error: LexicalErrorType::TabError })
location,
error: LexicalErrorType::TabError,
})
} }
} }
Ordering::Greater => { Ordering::Greater => {
if self.spaces >= other.spaces { if self.spaces >= other.spaces {
Ok(Ordering::Greater) Ok(Ordering::Greater)
} else { } else {
Err(LexicalError { Err(LexicalError { location, error: LexicalErrorType::TabError })
location,
error: LexicalErrorType::TabError,
})
} }
} }
Ordering::Equal => Ok(self.spaces.cmp(&other.spaces)), Ordering::Equal => Ok(self.spaces.cmp(&other.spaces)),
@ -63,7 +57,7 @@ pub struct Lexer<T: Iterator<Item = char>> {
chr1: Option<char>, chr1: Option<char>,
chr2: Option<char>, chr2: Option<char>,
location: Location, location: Location,
config_comment_prefix: Option<&'static str> config_comment_prefix: Option<&'static str>,
} }
pub static KEYWORDS: phf::Map<&'static str, Tok> = phf::phf_map! { pub static KEYWORDS: phf::Map<&'static str, Tok> = phf::phf_map! {
@ -136,11 +130,7 @@ where
T: Iterator<Item = char>, T: Iterator<Item = char>,
{ {
pub fn new(source: T) -> Self { pub fn new(source: T) -> Self {
let mut nlh = NewlineHandler { let mut nlh = NewlineHandler { source, chr0: None, chr1: None };
source,
chr0: None,
chr1: None,
};
nlh.shift(); nlh.shift();
nlh.shift(); nlh.shift();
nlh nlh
@ -169,7 +159,7 @@ where
self.shift(); self.shift();
} else { } else {
// Transform MAC EOL into \n // Transform MAC EOL into \n
self.chr0 = Some('\n') self.chr0 = Some('\n');
} }
} else { } else {
break; break;
@ -189,13 +179,13 @@ where
chars: input, chars: input,
at_begin_of_line: true, at_begin_of_line: true,
nesting: 0, nesting: 0,
indentation_stack: vec![Default::default()], indentation_stack: vec![IndentationLevel::default()],
pending: Vec::new(), pending: Vec::new(),
chr0: None, chr0: None,
location: start, location: start,
chr1: None, chr1: None,
chr2: None, chr2: None,
config_comment_prefix: Some(" nac3:") config_comment_prefix: Some(" nac3:"),
}; };
lxr.next_char(); lxr.next_char();
lxr.next_char(); lxr.next_char();
@ -217,11 +207,9 @@ where
let mut saw_f = false; let mut saw_f = false;
loop { loop {
// Detect r"", f"", b"" and u"" // Detect r"", f"", b"" and u""
if !(saw_b || saw_u || saw_f) && matches!(self.chr0, Some('b') | Some('B')) { if !(saw_b || saw_u || saw_f) && matches!(self.chr0, Some('b' | 'B')) {
saw_b = true; saw_b = true;
} else if !(saw_b || saw_r || saw_u || saw_f) } else if !(saw_b || saw_r || saw_u || saw_f) && matches!(self.chr0, Some('u' | 'U')) {
&& matches!(self.chr0, Some('u') | Some('U'))
{
saw_u = true; saw_u = true;
} else if !(saw_r || saw_u) && (self.chr0 == Some('r') || self.chr0 == Some('R')) { } else if !(saw_r || saw_u) && (self.chr0 == Some('r') || self.chr0 == Some('R')) {
saw_r = true; saw_r = true;
@ -287,15 +275,15 @@ where
let end_pos = self.get_pos(); let end_pos = self.get_pos();
let value = match i128::from_str_radix(&value_text, radix) { let value = match i128::from_str_radix(&value_text, radix) {
Ok(value) => value, Ok(value) => value,
Err(e) => { Err(e) => match e.kind() {
match e.kind() { IntErrorKind::PosOverflow | IntErrorKind::NegOverflow => i128::MAX,
IntErrorKind::PosOverflow | IntErrorKind::NegOverflow => i128::MAX, _ => {
_ => return Err(LexicalError { return Err(LexicalError {
error: LexicalErrorType::OtherError(format!("{:?}", e)), error: LexicalErrorType::OtherError(format!("{e:?}")),
location: start_pos, location: start_pos,
}), })
} }
} },
}; };
Ok((start_pos, Tok::Int { value }, end_pos)) Ok((start_pos, Tok::Int { value }, end_pos))
} }
@ -338,14 +326,7 @@ where
if self.chr0 == Some('j') || self.chr0 == Some('J') { if self.chr0 == Some('j') || self.chr0 == Some('J') {
self.next_char(); self.next_char();
let end_pos = self.get_pos(); let end_pos = self.get_pos();
Ok(( Ok((start_pos, Tok::Complex { real: 0.0, imag: value }, end_pos))
start_pos,
Tok::Complex {
real: 0.0,
imag: value,
},
end_pos,
))
} else { } else {
let end_pos = self.get_pos(); let end_pos = self.get_pos();
Ok((start_pos, Tok::Float { value }, end_pos)) Ok((start_pos, Tok::Float { value }, end_pos))
@ -364,7 +345,7 @@ where
let value = value_text.parse::<i128>().ok(); let value = value_text.parse::<i128>().ok();
let nonzero = match value { let nonzero = match value {
Some(value) => value != 0i128, Some(value) => value != 0i128,
None => true None => true,
}; };
if start_is_zero && nonzero { if start_is_zero && nonzero {
return Err(LexicalError { return Err(LexicalError {
@ -379,7 +360,7 @@ where
/// Consume a sequence of numbers with the given radix, /// Consume a sequence of numbers with the given radix,
/// the digits can be decorated with underscores /// the digits can be decorated with underscores
/// like this: '1_2_3_4' == '1234' /// like this: `'1_2_3_4'` == `'1234'`
fn radix_run(&mut self, radix: u32) -> String { fn radix_run(&mut self, radix: u32) -> String {
let mut value_text = String::new(); let mut value_text = String::new();
@ -412,7 +393,7 @@ where
2 => matches!(c, Some('0'..='1')), 2 => matches!(c, Some('0'..='1')),
8 => matches!(c, Some('0'..='7')), 8 => matches!(c, Some('0'..='7')),
10 => matches!(c, Some('0'..='9')), 10 => matches!(c, Some('0'..='9')),
16 => matches!(c, Some('0'..='9') | Some('a'..='f') | Some('A'..='F')), 16 => matches!(c, Some('0'..='9' | 'a'..='f' | 'A'..='F')),
other => unimplemented!("Radix not implemented: {}", other), other => unimplemented!("Radix not implemented: {}", other),
} }
} }
@ -420,8 +401,8 @@ where
/// Test if we face '[eE][-+]?[0-9]+' /// Test if we face '[eE][-+]?[0-9]+'
fn at_exponent(&self) -> bool { fn at_exponent(&self) -> bool {
match self.chr0 { match self.chr0 {
Some('e') | Some('E') => match self.chr1 { Some('e' | 'E') => match self.chr1 {
Some('+') | Some('-') => matches!(self.chr2, Some('0'..='9')), Some('+' | '-') => matches!(self.chr2, Some('0'..='9')),
Some('0'..='9') => true, Some('0'..='9') => true,
_ => false, _ => false,
}, },
@ -433,19 +414,17 @@ where
fn lex_comment(&mut self) -> Option<Spanned> { fn lex_comment(&mut self) -> Option<Spanned> {
self.next_char(); self.next_char();
// if possibly nac3 pseudocomment, special handling for `# nac3:` // if possibly nac3 pseudocomment, special handling for `# nac3:`
let (mut prefix, mut is_comment) = self let (mut prefix, mut is_comment) =
.config_comment_prefix self.config_comment_prefix.map_or_else(|| ("".chars(), false), |v| (v.chars(), true));
.map_or_else(|| ("".chars(), false), |v| (v.chars(), true));
// for the correct location of config comment // for the correct location of config comment
let mut start_loc = self.location; let mut start_loc = self.location;
start_loc.go_left(); start_loc.go_left();
loop { loop {
match self.chr0 { match self.chr0 {
Some('\n') => return None, Some('\n') | None => return None,
None => return None,
Some(c) => { Some(c) => {
if let (true, Some(p)) = (is_comment, prefix.next()) { if let (true, Some(p)) = (is_comment, prefix.next()) {
is_comment = is_comment && c == p is_comment = is_comment && c == p;
} else { } else {
// done checking prefix, if is comment then return the spanned // done checking prefix, if is comment then return the spanned
if is_comment { if is_comment {
@ -460,22 +439,20 @@ where
return Some(( return Some((
start_loc, start_loc,
Tok::ConfigComment { content: content.trim().into() }, Tok::ConfigComment { content: content.trim().into() },
self.location self.location,
)); ));
} }
} }
} }
} }
self.next_char(); self.next_char();
}; }
} }
fn unicode_literal(&mut self, literal_number: usize) -> Result<char, LexicalError> { fn unicode_literal(&mut self, literal_number: usize) -> Result<char, LexicalError> {
let mut p: u32 = 0u32; let mut p: u32 = 0u32;
let unicode_error = LexicalError { let unicode_error =
error: LexicalErrorType::UnicodeError, LexicalError { error: LexicalErrorType::UnicodeError, location: self.get_pos() };
location: self.get_pos(),
};
for i in 1..=literal_number { for i in 1..=literal_number {
match self.next_char() { match self.next_char() {
Some(c) => match c.to_digit(16) { Some(c) => match c.to_digit(16) {
@ -486,8 +463,8 @@ where
} }
} }
match p { match p {
0xD800..=0xDFFF => Ok(std::char::REPLACEMENT_CHARACTER), 0xD800..=0xDFFF => Ok(char::REPLACEMENT_CHARACTER),
_ => std::char::from_u32(p).ok_or(unicode_error), _ => char::from_u32(p).ok_or(unicode_error),
} }
} }
@ -496,7 +473,7 @@ where
octet_content.push(first); octet_content.push(first);
while octet_content.len() < 3 { while octet_content.len() < 3 {
if let Some('0'..='7') = self.chr0 { if let Some('0'..='7') = self.chr0 {
octet_content.push(self.next_char().unwrap()) octet_content.push(self.next_char().unwrap());
} else { } else {
break; break;
} }
@ -530,10 +507,8 @@ where
} }
} }
} }
unicode_names2::character(&name).ok_or(LexicalError { unicode_names2::character(&name)
error: LexicalErrorType::UnicodeError, .ok_or(LexicalError { error: LexicalErrorType::UnicodeError, location: start_pos })
location: start_pos,
})
} }
fn lex_string( fn lex_string(
@ -566,7 +541,7 @@ where
} else if is_raw { } else if is_raw {
string_content.push('\\'); string_content.push('\\');
if let Some(c) = self.next_char() { if let Some(c) = self.next_char() {
string_content.push(c) string_content.push(c);
} else { } else {
return Err(LexicalError { return Err(LexicalError {
error: LexicalErrorType::StringError, error: LexicalErrorType::StringError,
@ -599,7 +574,7 @@ where
Some('u') if !is_bytes => string_content.push(self.unicode_literal(4)?), Some('u') if !is_bytes => string_content.push(self.unicode_literal(4)?),
Some('U') if !is_bytes => string_content.push(self.unicode_literal(8)?), Some('U') if !is_bytes => string_content.push(self.unicode_literal(8)?),
Some('N') if !is_bytes => { Some('N') if !is_bytes => {
string_content.push(self.parse_unicode_name()?) string_content.push(self.parse_unicode_name()?);
} }
Some(c) => { Some(c) => {
string_content.push('\\'); string_content.push('\\');
@ -650,20 +625,15 @@ where
let end_pos = self.get_pos(); let end_pos = self.get_pos();
let tok = if is_bytes { let tok = if is_bytes {
Tok::Bytes { Tok::Bytes { value: string_content.chars().map(|c| c as u8).collect() }
value: string_content.chars().map(|c| c as u8).collect(),
}
} else { } else {
Tok::String { Tok::String { value: string_content, is_fstring }
value: string_content,
is_fstring,
}
}; };
Ok((start_pos, tok, end_pos)) Ok((start_pos, tok, end_pos))
} }
fn is_identifier_start(&self, c: char) -> bool { fn is_identifier_start(c: char) -> bool {
match c { match c {
'_' | 'a'..='z' | 'A'..='Z' => true, '_' | 'a'..='z' | 'A'..='Z' => true,
'+' | '-' | '*' | '/' | '=' | ' ' | '<' | '>' => false, '+' | '-' | '*' | '/' | '=' | ' ' | '<' | '>' => false,
@ -835,18 +805,14 @@ where
// Check if we have some character: // Check if we have some character:
if let Some(c) = self.chr0 { if let Some(c) = self.chr0 {
// First check identifier: // First check identifier:
if self.is_identifier_start(c) { if Self::is_identifier_start(c) {
let identifier = self.lex_identifier()?; let identifier = self.lex_identifier()?;
self.emit(identifier); self.emit(identifier);
} else if is_emoji_presentation(c) { } else if is_emoji_presentation(c) {
let tok_start = self.get_pos(); let tok_start = self.get_pos();
self.next_char(); self.next_char();
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit(( self.emit((tok_start, Tok::Name { name: c.to_string().into() }, tok_end));
tok_start,
Tok::Name { name: c.to_string().into() },
tok_end,
));
} else { } else {
self.consume_character(c)?; self.consume_character(c)?;
} }
@ -899,16 +865,13 @@ where
'=' => { '=' => {
let tok_start = self.get_pos(); let tok_start = self.get_pos();
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => { self.next_char();
self.next_char(); let tok_end = self.get_pos();
let tok_end = self.get_pos(); self.emit((tok_start, Tok::EqEqual, tok_end));
self.emit((tok_start, Tok::EqEqual, tok_end)); } else {
} let tok_end = self.get_pos();
_ => { self.emit((tok_start, Tok::Equal, tok_end));
let tok_end = self.get_pos();
self.emit((tok_start, Tok::Equal, tok_end));
}
} }
} }
'+' => { '+' => {
@ -934,16 +897,13 @@ where
} }
Some('*') => { Some('*') => {
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => { self.next_char();
self.next_char(); let tok_end = self.get_pos();
let tok_end = self.get_pos(); self.emit((tok_start, Tok::DoubleStarEqual, tok_end));
self.emit((tok_start, Tok::DoubleStarEqual, tok_end)); } else {
} let tok_end = self.get_pos();
_ => { self.emit((tok_start, Tok::DoubleStar, tok_end));
let tok_end = self.get_pos();
self.emit((tok_start, Tok::DoubleStar, tok_end));
}
} }
} }
_ => { _ => {
@ -963,16 +923,13 @@ where
} }
Some('/') => { Some('/') => {
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => { self.next_char();
self.next_char(); let tok_end = self.get_pos();
let tok_end = self.get_pos(); self.emit((tok_start, Tok::DoubleSlashEqual, tok_end));
self.emit((tok_start, Tok::DoubleSlashEqual, tok_end)); } else {
} let tok_end = self.get_pos();
_ => { self.emit((tok_start, Tok::DoubleSlash, tok_end));
let tok_end = self.get_pos();
self.emit((tok_start, Tok::DoubleSlash, tok_end));
}
} }
} }
_ => { _ => {
@ -1141,16 +1098,13 @@ where
match self.chr0 { match self.chr0 {
Some('<') => { Some('<') => {
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => { self.next_char();
self.next_char(); let tok_end = self.get_pos();
let tok_end = self.get_pos(); self.emit((tok_start, Tok::LeftShiftEqual, tok_end));
self.emit((tok_start, Tok::LeftShiftEqual, tok_end)); } else {
} let tok_end = self.get_pos();
_ => { self.emit((tok_start, Tok::LeftShift, tok_end));
let tok_end = self.get_pos();
self.emit((tok_start, Tok::LeftShift, tok_end));
}
} }
} }
Some('=') => { Some('=') => {
@ -1170,16 +1124,13 @@ where
match self.chr0 { match self.chr0 {
Some('>') => { Some('>') => {
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => { self.next_char();
self.next_char(); let tok_end = self.get_pos();
let tok_end = self.get_pos(); self.emit((tok_start, Tok::RightShiftEqual, tok_end));
self.emit((tok_start, Tok::RightShiftEqual, tok_end)); } else {
} let tok_end = self.get_pos();
_ => { self.emit((tok_start, Tok::RightShift, tok_end));
let tok_end = self.get_pos();
self.emit((tok_start, Tok::RightShift, tok_end));
}
} }
} }
Some('=') => { Some('=') => {
@ -1333,13 +1284,14 @@ where
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{make_tokenizer, NewlineHandler, Tok}; use super::{make_tokenizer, NewlineHandler, Tok};
use nac3ast::FileName;
const WINDOWS_EOL: &str = "\r\n"; const WINDOWS_EOL: &str = "\r\n";
const MAC_EOL: &str = "\r"; const MAC_EOL: &str = "\r";
const UNIX_EOL: &str = "\n"; const UNIX_EOL: &str = "\n";
pub fn lex_source(source: &str) -> Vec<Tok> { pub fn lex_source(source: &str) -> Vec<Tok> {
let lexer = make_tokenizer(source, Default::default()); let lexer = make_tokenizer(source, FileName::default());
lexer.map(|x| x.unwrap().1).collect() lexer.map(|x| x.unwrap().1).collect()
} }
@ -1419,7 +1371,7 @@ class Foo(A, B):
Dedent, Dedent,
Dedent Dedent
] ]
) );
} }
#[test] #[test]
@ -1439,14 +1391,8 @@ class Foo(A, B):
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![
Tok::String { Tok::String { value: "\\\\".to_owned(), is_fstring: false },
value: "\\\\".to_owned(), Tok::String { value: "\\".to_owned(), is_fstring: false },
is_fstring: false,
},
Tok::String {
value: "\\".to_owned(),
is_fstring: false,
},
Tok::Newline, Tok::Newline,
] ]
); );
@ -1459,27 +1405,13 @@ class Foo(A, B):
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![
Tok::Int { Tok::Int { value: 47i128 },
value: 47i128, Tok::Int { value: 13i128 },
}, Tok::Int { value: 0i128 },
Tok::Int { Tok::Int { value: 123i128 },
value: 13i128,
},
Tok::Int {
value: 0i128,
},
Tok::Int {
value: 123i128,
},
Tok::Float { value: 0.2 }, Tok::Float { value: 0.2 },
Tok::Complex { Tok::Complex { real: 0.0, imag: 2.0 },
real: 0.0, Tok::Complex { real: 0.0, imag: 2.2 },
imag: 2.0,
},
Tok::Complex {
real: 0.0,
imag: 2.2,
},
Tok::Newline, Tok::Newline,
] ]
); );
@ -1539,21 +1471,13 @@ class Foo(A, B):
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![
Tok::Name { Tok::Name { name: String::from("avariable").into() },
name: String::from("avariable").into(),
},
Tok::Equal, Tok::Equal,
Tok::Int { Tok::Int { value: 99i128 },
value: 99i128
},
Tok::Plus, Tok::Plus,
Tok::Int { Tok::Int { value: 2i128 },
value: 2i128
},
Tok::Minus, Tok::Minus,
Tok::Int { Tok::Int { value: 0i128 },
value: 0i128
},
Tok::Newline, Tok::Newline,
] ]
); );
@ -1740,42 +1664,15 @@ class Foo(A, B):
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![
Tok::String { Tok::String { value: String::from("double"), is_fstring: false },
value: String::from("double"), Tok::String { value: String::from("single"), is_fstring: false },
is_fstring: false, Tok::String { value: String::from("can't"), is_fstring: false },
}, Tok::String { value: String::from("\\\""), is_fstring: false },
Tok::String { Tok::String { value: String::from("\t\r\n"), is_fstring: false },
value: String::from("single"), Tok::String { value: String::from("\\g"), is_fstring: false },
is_fstring: false, Tok::String { value: String::from("raw\\'"), is_fstring: false },
}, Tok::String { value: String::from("Đ"), is_fstring: false },
Tok::String { Tok::String { value: String::from("\u{80}\u{0}a"), is_fstring: false },
value: String::from("can't"),
is_fstring: false,
},
Tok::String {
value: String::from("\\\""),
is_fstring: false,
},
Tok::String {
value: String::from("\t\r\n"),
is_fstring: false,
},
Tok::String {
value: String::from("\\g"),
is_fstring: false,
},
Tok::String {
value: String::from("raw\\'"),
is_fstring: false,
},
Tok::String {
value: String::from("Đ"),
is_fstring: false,
},
Tok::String {
value: String::from("\u{80}\u{0}a"),
is_fstring: false,
},
Tok::Newline, Tok::Newline,
] ]
); );
@ -1830,7 +1727,7 @@ class Foo(A, B):
#[test] #[test]
fn test_escape_char_in_byte_literal() { fn test_escape_char_in_byte_literal() {
// backslash does not escape // backslash does not escape
let source = r##"b"omkmok\Xaa""##; let source = r#"b"omkmok\Xaa""#;
let tokens = lex_source(source); let tokens = lex_source(source);
let res = vec![111, 109, 107, 109, 111, 107, 92, 88, 97, 97]; let res = vec![111, 109, 107, 109, 111, 107, 92, 88, 97, 97];
assert_eq!(tokens, vec![Tok::Bytes { value: res }, Tok::Newline]); assert_eq!(tokens, vec![Tok::Bytes { value: res }, Tok::Newline]);
@ -1840,41 +1737,17 @@ class Foo(A, B):
fn test_raw_byte_literal() { fn test_raw_byte_literal() {
let source = r"rb'\x1z'"; let source = r"rb'\x1z'";
let tokens = lex_source(source); let tokens = lex_source(source);
assert_eq!( assert_eq!(tokens, vec![Tok::Bytes { value: b"\\x1z".to_vec() }, Tok::Newline]);
tokens,
vec![
Tok::Bytes {
value: b"\\x1z".to_vec()
},
Tok::Newline
]
);
let source = r"rb'\\'"; let source = r"rb'\\'";
let tokens = lex_source(source); let tokens = lex_source(source);
assert_eq!( assert_eq!(tokens, vec![Tok::Bytes { value: b"\\\\".to_vec() }, Tok::Newline]);
tokens,
vec![
Tok::Bytes {
value: b"\\\\".to_vec()
},
Tok::Newline
]
)
} }
#[test] #[test]
fn test_escape_octet() { fn test_escape_octet() {
let source = r##"b'\43a\4\1234'"##; let source = r"b'\43a\4\1234'";
let tokens = lex_source(source); let tokens = lex_source(source);
assert_eq!( assert_eq!(tokens, vec![Tok::Bytes { value: b"#a\x04S4".to_vec() }, Tok::Newline]);
tokens,
vec![
Tok::Bytes {
value: b"#a\x04S4".to_vec()
},
Tok::Newline
]
)
} }
#[test] #[test]
@ -1883,13 +1756,7 @@ class Foo(A, B):
let tokens = lex_source(source); let tokens = lex_source(source);
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![Tok::String { value: "\u{2002}".to_owned(), is_fstring: false }, Tok::Newline]
Tok::String { );
value: "\u{2002}".to_owned(),
is_fstring: false,
},
Tok::Newline
]
)
} }
} }

View File

@ -15,6 +15,24 @@
//! //!
//! ``` //! ```
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
clippy::enum_glob_use,
clippy::fn_params_excessive_bools,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::too_many_lines,
clippy::wildcard_imports
)]
#[macro_use] #[macro_use]
extern crate log; extern crate log;
use lalrpop_util::lalrpop_mod; use lalrpop_util::lalrpop_mod;
@ -27,9 +45,16 @@ pub mod lexer;
pub mod mode; pub mod mode;
pub mod parser; pub mod parser;
lalrpop_mod!( lalrpop_mod!(
#[allow(clippy::all)] #[allow(
#[allow(unused)] future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
unused,
clippy::all,
clippy::pedantic
)]
python python
); );
pub mod token;
pub mod config_comment_helper; pub mod config_comment_helper;
pub mod token;

View File

@ -7,11 +7,14 @@
use std::iter; use std::iter;
use crate::ast::{self, FileName}; use nac3ast::Location;
use crate::error::ParseError;
use crate::lexer;
pub use crate::mode::Mode; pub use crate::mode::Mode;
use crate::python; use crate::{
ast::{self, FileName},
error::ParseError,
lexer, python,
};
/* /*
* Parse python code. * Parse python code.
@ -63,7 +66,7 @@ pub fn parse_program(source: &str, file: FileName) -> Result<ast::Suite, ParseEr
/// ///
/// ``` /// ```
pub fn parse_expression(source: &str) -> Result<ast::Expr, ParseError> { pub fn parse_expression(source: &str) -> Result<ast::Expr, ParseError> {
parse(source, Mode::Expression, Default::default()).map(|top| match top { parse(source, Mode::Expression, FileName::default()).map(|top| match top {
ast::Mod::Expression { body } => *body, ast::Mod::Expression { body } => *body,
_ => unreachable!(), _ => unreachable!(),
}) })
@ -72,12 +75,10 @@ pub fn parse_expression(source: &str) -> Result<ast::Expr, ParseError> {
// Parse a given source code // Parse a given source code
pub fn parse(source: &str, mode: Mode, file: FileName) -> Result<ast::Mod, ParseError> { pub fn parse(source: &str, mode: Mode, file: FileName) -> Result<ast::Mod, ParseError> {
let lxr = lexer::make_tokenizer(source, file); let lxr = lexer::make_tokenizer(source, file);
let marker_token = (Default::default(), mode.to_marker(), Default::default()); let marker_token = (Location::default(), mode.to_marker(), Location::default());
let tokenizer = iter::once(Ok(marker_token)).chain(lxr); let tokenizer = iter::once(Ok(marker_token)).chain(lxr);
python::TopParser::new() python::TopParser::new().parse(tokenizer).map_err(ParseError::from)
.parse(tokenizer)
.map_err(ParseError::from)
} }
#[cfg(test)] #[cfg(test)]
@ -86,42 +87,42 @@ mod tests {
#[test] #[test]
fn test_parse_empty() { fn test_parse_empty() {
let parse_ast = parse_program("", Default::default()).unwrap(); let parse_ast = parse_program("", FileName::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_print_hello() { fn test_parse_print_hello() {
let source = String::from("print('Hello world')"); let source = String::from("print('Hello world')");
let parse_ast = parse_program(&source, Default::default()).unwrap(); let parse_ast = parse_program(&source, FileName::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_print_2() { fn test_parse_print_2() {
let source = String::from("print('Hello world', 2)"); let source = String::from("print('Hello world', 2)");
let parse_ast = parse_program(&source, Default::default()).unwrap(); let parse_ast = parse_program(&source, FileName::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_kwargs() { fn test_parse_kwargs() {
let source = String::from("my_func('positional', keyword=2)"); let source = String::from("my_func('positional', keyword=2)");
let parse_ast = parse_program(&source, Default::default()).unwrap(); let parse_ast = parse_program(&source, FileName::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_if_elif_else() { fn test_parse_if_elif_else() {
let source = String::from("if 1: 10\nelif 2: 20\nelse: 30"); let source = String::from("if 1: 10\nelif 2: 20\nelse: 30");
let parse_ast = parse_program(&source, Default::default()).unwrap(); let parse_ast = parse_program(&source, FileName::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_parse_lambda() { fn test_parse_lambda() {
let source = "lambda x, y: x * y"; // lambda(x, y): x * y"; let source = "lambda x, y: x * y"; // lambda(x, y): x * y";
let parse_ast = parse_program(source, Default::default()).unwrap(); let parse_ast = parse_program(source, FileName::default()).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
@ -129,7 +130,7 @@ mod tests {
fn test_parse_tuples() { fn test_parse_tuples() {
let source = "a, b = 4, 5"; let source = "a, b = 4, 5";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap()); insta::assert_debug_snapshot!(parse_program(source, FileName::default()).unwrap());
} }
#[test] #[test]
@ -140,7 +141,7 @@ class Foo(A, B):
pass pass
def method_with_default(self, arg='default'): def method_with_default(self, arg='default'):
pass"; pass";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap()); insta::assert_debug_snapshot!(parse_program(source, FileName::default()).unwrap());
} }
#[test] #[test]
@ -183,7 +184,7 @@ while i < 2: # nac3: 4
# nac3: if1 # nac3: if1
if 1: # nac3: if2 if 1: # nac3: if2
3"; 3";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap()); insta::assert_debug_snapshot!(parse_program(source, FileName::default()).unwrap());
} }
#[test] #[test]
@ -196,7 +197,7 @@ while test: # nac3: while3
# nac3: simple assign0 # nac3: simple assign0
a = 3 # nac3: simple assign1 a = 3 # nac3: simple assign1
"; ";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap()); insta::assert_debug_snapshot!(parse_program(source, FileName::default()).unwrap());
} }
#[test] #[test]
@ -215,7 +216,7 @@ if a: # nac3: small2
for i in a: # nac3: for1 for i in a: # nac3: for1
pass pass
"; ";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap()); insta::assert_debug_snapshot!(parse_program(source, FileName::default()).unwrap());
} }
#[test] #[test]
@ -224,6 +225,6 @@ for i in a: # nac3: for1
if a: # nac3: something if a: # nac3: something
a = 3 a = 3
"; ";
assert!(parse_program(source, Default::default()).is_err()); assert!(parse_program(source, FileName::default()).is_err());
} }
} }

View File

@ -1,6 +1,7 @@
//! Different token definitions. //! Different token definitions.
//! Loosely based on token.h from CPython source: //! Loosely based on token.h from CPython source:
use std::fmt::{self, Write}; use std::fmt::{self, Write};
use crate::ast; use crate::ast;
/// Python source code can be tokenized in a sequence of these tokens. /// Python source code can be tokenized in a sequence of these tokens.
@ -111,15 +112,23 @@ impl fmt::Display for Tok {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use Tok::*; use Tok::*;
match self { match self {
Name { name } => write!(f, "'{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *name)), Name { name } => {
Int { value } => if *value != i128::MAX { write!(f, "'{}'", value) } else { write!(f, "'#OFL#'") }, write!(f, "'{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *name))
Float { value } => write!(f, "'{}'", value), }
Complex { real, imag } => write!(f, "{}j{}", real, imag), Int { value } => {
if *value == i128::MAX {
write!(f, "'#OFL#'")
} else {
write!(f, "'{value}'")
}
}
Float { value } => write!(f, "'{value}'"),
Complex { real, imag } => write!(f, "{real}j{imag}"),
String { value, is_fstring } => { String { value, is_fstring } => {
if *is_fstring { if *is_fstring {
write!(f, "f")? write!(f, "f")?;
} }
write!(f, "{:?}", value) write!(f, "{value:?}")
} }
Bytes { value } => { Bytes { value } => {
write!(f, "b\"")?; write!(f, "b\"")?;
@ -129,12 +138,16 @@ impl fmt::Display for Tok {
10 => f.write_str("\\n")?, 10 => f.write_str("\\n")?,
13 => f.write_str("\\r")?, 13 => f.write_str("\\r")?,
32..=126 => f.write_char(*i as char)?, 32..=126 => f.write_char(*i as char)?,
_ => write!(f, "\\x{:02x}", i)?, _ => write!(f, "\\x{i:02x}")?,
} }
} }
f.write_str("\"") f.write_str("\"")
} }
ConfigComment { content } => write!(f, "ConfigComment: '{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *content)), ConfigComment { content } => write!(
f,
"ConfigComment: '{}'",
ast::get_str_from_ref(&ast::get_str_ref_lock(), *content)
),
Newline => f.write_str("Newline"), Newline => f.write_str("Newline"),
Indent => f.write_str("Indent"), Indent => f.write_str("Indent"),
Dedent => f.write_str("Dedent"), Dedent => f.write_str("Dedent"),

View File

@ -2,14 +2,15 @@
name = "nac3standalone" name = "nac3standalone"
version = "0.1.0" version = "0.1.0"
authors = ["M-Labs"] authors = ["M-Labs"]
edition = "2018" edition = "2021"
[features]
no-escape-analysis = ["nac3core/no-escape-analysis"]
[dependencies] [dependencies]
parking_lot = "0.11.1" parking_lot = "0.12"
nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
[dependencies.inkwell] [dependencies.clap]
version = "0.1.0-beta.4" version = "4.5"
default-features = false features = ["derive"]
features = ["llvm13-0", "target-x86", "target-arm", "target-riscv", "no-libffi-linking"]

4
nac3standalone/demo/.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
*.bc
*.ll
*.o
/demo

View File

@ -0,0 +1,68 @@
#!/usr/bin/env bash
set -e
if [ -z "$1" ]; then
echo "No argument supplied"
exit 1
fi
declare -a nac3args
while [ $# -gt 1 ]; do
case "$1" in
--help)
echo "Usage: check_demo.sh [--debug] [-i686] -- [NAC3ARGS...] demo"
exit
;;
--debug)
debug=1
;;
-i686)
i686=1
;;
--)
shift
break
;;
*)
echo "Unrecognized argument \"$1\""
exit 1
;;
esac
shift
done
while [ $# -gt 1 ]; do
nac3args+=("$1")
shift
done
demo="$1"
echo "### Checking $demo..."
echo ">>>>>> Running $demo with the Python interpreter"
./interpret_demo.py "$demo" > interpreted.log
if [ -n "$i686" ]; then
echo "...... Trying NAC3's 32-bit code generator output"
if [ -n "$debug" ]; then
./run_demo.sh --debug -i686 --out run_32.log -- "${nac3args[@]}" "$demo"
else
./run_demo.sh -i686 --out run_32.log -- "${nac3args[@]}" "$demo"
fi
diff -Nau interpreted.log run_32.log
fi
echo "...... Trying NAC3's 64-bit code generator output"
if [ -n "$debug" ]; then
./run_demo.sh --debug --out run_64.log -- "${nac3args[@]}" "$demo"
else
./run_demo.sh --out run_64.log -- "${nac3args[@]}" "$demo"
fi
diff -Nau interpreted.log run_64.log
echo "...... OK"
rm -f interpreted.log \
run_32.log run_64.log

View File

@ -2,14 +2,15 @@
set -e set -e
if [ "$1" == "--help" ]; then
echo "Usage: check_demos.sh [CHECKARGS...] [--] [NAC3ARGS...]"
exit
fi
count=0 count=0
for demo in src/*.py; do for demo in src/*.py; do
echo -n "checking $demo... " ./check_demo.sh "$@" "$demo"
./interpret_demo.py $demo > interpreted.log ((count += 1))
./run_demo.sh $demo > run.log
diff -Nau interpreted.log run.log
echo "ok"
let "count+=1"
done done
echo "Ran $count demo checks - PASSED" echo "Ran $count demo checks - PASSED"

138
nac3standalone/demo/demo.c Normal file
View File

@ -0,0 +1,138 @@
#include <inttypes.h>
#include <math.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
double dbl_nan(void) {
return NAN;
}
double dbl_inf(void) {
return INFINITY;
}
void output_bool(bool x) {
puts(x ? "True" : "False");
}
void output_int32(int32_t x) {
printf("%" PRId32 "\n", x);
}
void output_int64(int64_t x) {
printf("%" PRId64 "\n", x);
}
void output_uint32(uint32_t x) {
printf("%" PRIu32 "\n", x);
}
void output_uint64(uint64_t x) {
printf("%" PRIu64 "\n", x);
}
void output_float64(double x) {
if (isnan(x)) {
puts("nan");
} else {
printf("%f\n", x);
}
}
void output_range(int32_t range[3]) {
printf("range(");
printf("%d, %d", range[0], range[1]);
if (range[2] != 1) {
printf(", %d", range[2]);
}
puts(")");
}
void output_asciiart(int32_t x) {
static const char* chars = " .,-:;i+hHM$*#@ ";
if (x < 0) {
putchar('\n');
} else {
putchar(chars[x]);
}
}
struct cslice {
void* data;
size_t len;
};
void output_int32_list(struct cslice* slice) {
const int32_t* data = (int32_t*)slice->data;
putchar('[');
for (size_t i = 0; i < slice->len; ++i) {
if (i == slice->len - 1) {
printf("%d", data[i]);
} else {
printf("%d, ", data[i]);
}
}
putchar(']');
putchar('\n');
}
void output_str(struct cslice* slice) {
const char* data = (const char*)slice->data;
for (size_t i = 0; i < slice->len; ++i) {
putchar(data[i]);
}
}
void output_strln(struct cslice* slice) {
output_str(slice);
putchar('\n');
}
uint64_t dbg_stack_address(__attribute__((unused)) struct cslice* slice) {
int i;
void* ptr = (void*)&i;
return (uintptr_t)ptr;
}
uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t context) {
printf("__nac3_personality(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context);
exit(101);
__builtin_unreachable();
}
// See `struct Exception<'a>` in
// https://github.com/m-labs/artiq/blob/master/artiq/firmware/libeh/eh_artiq.rs
struct Exception {
uint32_t id;
struct cslice file;
uint32_t line;
uint32_t column;
struct cslice function;
struct cslice message;
int64_t param[3];
};
uint32_t __nac3_raise(struct Exception* e) {
printf("__nac3_raise called. Exception details:\n");
printf(" ID: %" PRIu32 "\n", e->id);
printf(" Location: %*s:%" PRIu32 ":%" PRIu32 "\n", (int)e->file.len, (const char*)e->file.data, e->line,
e->column);
printf(" Function: %*s\n", (int)e->function.len, (const char*)e->function.data);
printf(" Message: \"%*s\"\n", (int)e->message.len, (const char*)e->message.data);
printf(" Params: {0}=%" PRId64 ", {1}=%" PRId64 ", {2}=%" PRId64 "\n", e->param[0], e->param[1], e->param[2]);
exit(101);
__builtin_unreachable();
}
void __nac3_end_catch(void) {}
extern int32_t run(void);
int main(void) {
run();
}

View File

@ -1,90 +0,0 @@
mod cslice {
// copied from https://github.com/dherman/cslice
use std::marker::PhantomData;
use std::slice;
#[repr(C)]
#[derive(Clone, Copy)]
pub struct CSlice<'a, T> {
base: *const T,
len: usize,
marker: PhantomData<&'a ()>,
}
impl<'a, T> AsRef<[T]> for CSlice<'a, T> {
fn as_ref(&self) -> &[T] {
unsafe { slice::from_raw_parts(self.base, self.len) }
}
}
}
#[no_mangle]
pub extern "C" fn output_int32(x: i32) {
println!("{}", x);
}
#[no_mangle]
pub extern "C" fn output_int64(x: i64) {
println!("{}", x);
}
#[no_mangle]
pub extern "C" fn output_uint32(x: u32) {
println!("{}", x);
}
#[no_mangle]
pub extern "C" fn output_uint64(x: u64) {
println!("{}", x);
}
#[no_mangle]
pub extern "C" fn output_float64(x: f64) {
// debug output to preserve the digits after the decimal points
// to match python `print` function
println!("{:?}", x);
}
#[no_mangle]
pub extern "C" fn output_asciiart(x: i32) {
let chars = " .,-:;i+hHM$*#@ ";
if x < 0 {
println!("");
} else {
print!("{}", chars.chars().nth(x as usize).unwrap());
}
}
#[no_mangle]
pub extern "C" fn output_int32_list(x: &cslice::CSlice<i32>) {
print!("[");
let mut it = x.as_ref().iter().peekable();
while let Some(e) = it.next() {
if it.peek().is_none() {
print!("{}", e);
} else {
print!("{}, ", e);
}
}
println!("]");
}
#[no_mangle]
pub extern "C" fn __nac3_personality(_state: u32, _exception_object: u32, _context: u32) -> u32 {
unimplemented!();
}
#[no_mangle]
pub extern "C" fn __nac3_raise(_state: u32, _exception_object: u32, _context: u32) -> u32 {
unimplemented!();
}
extern "C" {
fn run() -> i32;
}
fn main() {
unsafe {
run();
}
}

View File

@ -3,10 +3,15 @@
import sys import sys
import importlib.util import importlib.util
import importlib.machinery import importlib.machinery
import math
import numpy as np
import numpy.typing as npt
import scipy as sp
import pathlib import pathlib
from numpy import int32, int64, uint32, uint64 from numpy import int32, int64, uint32, uint64
from typing import TypeVar, Generic from scipy import special
from typing import TypeVar, Generic, Literal, Union
T = TypeVar('T') T = TypeVar('T')
class Option(Generic[T]): class Option(Generic[T]):
@ -41,26 +46,99 @@ def Some(v: T) -> Option[T]:
none = Option(None) none = Option(None)
class _ConstGenericMarker:
pass
def ConstGeneric(name, constraint):
return TypeVar(name, _ConstGenericMarker, constraint)
N = TypeVar("N", bound=np.uint64)
class _NDArrayDummy(Generic[T, N]):
pass
# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic
NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]]
def _bool(x):
if isinstance(x, np.ndarray):
return np.bool_(x)
else:
return bool(x)
def _float(x):
if isinstance(x, np.ndarray):
return np.float_(x)
else:
return float(x)
def round_away_zero(x):
if isinstance(x, np.ndarray):
return np.vectorize(round_away_zero)(x)
else:
if x >= 0.0:
return math.floor(x + 0.5)
else:
return math.ceil(x - 0.5)
def _floor(x):
if isinstance(x, np.ndarray):
return np.vectorize(_floor)(x)
else:
return math.floor(x)
def _ceil(x):
if isinstance(x, np.ndarray):
return np.vectorize(_ceil)(x)
else:
return math.ceil(x)
def patch(module): def patch(module):
def dbl_nan():
return np.nan
def dbl_inf():
return np.inf
def output_asciiart(x): def output_asciiart(x):
if x < 0: if x < 0:
sys.stdout.write("\n") sys.stdout.write("\n")
else: else:
sys.stdout.write(" .,-:;i+hHM$*#@ "[x]) sys.stdout.write(" .,-:;i+hHM$*#@ "[x])
def output_float(x):
print("%f" % x)
def output_strln(x):
print(x, end='')
def dbg_stack_address(_):
return 0
def extern(fun): def extern(fun):
name = fun.__name__ name = fun.__name__
if name == "output_asciiart": if name == "dbl_nan":
return dbl_nan
elif name == "dbl_inf":
return dbl_inf
elif name == "output_asciiart":
return output_asciiart return output_asciiart
elif name == "output_float64":
return output_float
elif name == "output_str":
return output_strln
elif name in { elif name in {
"output_bool",
"output_int32", "output_int32",
"output_int64", "output_int64",
"output_int32_list", "output_int32_list",
"output_uint32", "output_uint32",
"output_uint64", "output_uint64",
"output_float64" "output_strln",
"output_range",
}: }:
return print return print
elif name == "dbg_stack_address":
return dbg_stack_address
else: else:
raise NotImplementedError raise NotImplementedError
@ -68,13 +146,102 @@ def patch(module):
module.int64 = int64 module.int64 = int64
module.uint32 = uint32 module.uint32 = uint32
module.uint64 = uint64 module.uint64 = uint64
module.bool = _bool
module.float = _float
module.TypeVar = TypeVar module.TypeVar = TypeVar
module.ConstGeneric = ConstGeneric
module.Generic = Generic module.Generic = Generic
module.Literal = Literal
module.extern = extern module.extern = extern
module.Option = Option module.Option = Option
module.Some = Some module.Some = Some
module.none = none module.none = none
# Builtin Math functions
module.round = round_away_zero
module.round64 = round_away_zero
module.np_round = np.round
module.floor = _floor
module.floor64 = _floor
module.np_floor = np.floor
module.ceil = _ceil
module.ceil64 = _ceil
module.np_ceil = np.ceil
# NumPy NDArray factory functions
module.ndarray = NDArray
module.np_ndarray = np.ndarray
module.np_empty = np.empty
module.np_zeros = np.zeros
module.np_ones = np.ones
module.np_full = np.full
module.np_eye = np.eye
module.np_identity = np.identity
module.np_array = np.array
# NumPy Math functions
module.np_isnan = np.isnan
module.np_isinf = np.isinf
module.np_min = np.min
module.np_minimum = np.minimum
module.np_argmin = np.argmin
module.np_max = np.max
module.np_maximum = np.maximum
module.np_argmax = np.argmax
module.np_sin = np.sin
module.np_cos = np.cos
module.np_exp = np.exp
module.np_exp2 = np.exp2
module.np_log = np.log
module.np_log10 = np.log10
module.np_log2 = np.log2
module.np_fabs = np.fabs
module.np_trunc = np.trunc
module.np_sqrt = np.sqrt
module.np_rint = np.rint
module.np_tan = np.tan
module.np_arcsin = np.arcsin
module.np_arccos = np.arccos
module.np_arctan = np.arctan
module.np_sinh = np.sinh
module.np_cosh = np.cosh
module.np_tanh = np.tanh
module.np_arcsinh = np.arcsinh
module.np_arccosh = np.arccosh
module.np_arctanh = np.arctanh
module.np_expm1 = np.expm1
module.np_cbrt = np.cbrt
module.np_arctan2 = np.arctan2
module.np_copysign = np.copysign
module.np_fmax = np.fmax
module.np_fmin = np.fmin
module.np_ldexp = np.ldexp
module.np_hypot = np.hypot
module.np_nextafter = np.nextafter
module.np_transpose = np.transpose
module.np_reshape = np.reshape
# SciPy Math functions
module.sp_spec_erf = special.erf
module.sp_spec_erfc = special.erfc
module.sp_spec_gamma = special.gamma
module.sp_spec_gammaln = special.gammaln
module.sp_spec_j0 = special.j0
module.sp_spec_j1 = special.j1
# Linalg functions
module.np_dot = np.dot
module.np_linalg_cholesky = np.linalg.cholesky
module.np_linalg_qr = np.linalg.qr
module.np_linalg_svd = np.linalg.svd
module.np_linalg_inv = np.linalg.inv
module.np_linalg_pinv = np.linalg.pinv
module.np_linalg_matrix_power = np.linalg.matrix_power
module.np_linalg_det = np.linalg.det
module.sp_linalg_lu = lambda x: sp.linalg.lu(x, True)
module.sp_linalg_schur = sp.linalg.schur
module.sp_linalg_hessenberg = lambda x: sp.linalg.hessenberg(x, True)
def file_import(filename, prefix="file_import_"): def file_import(filename, prefix="file_import_"):
filename = pathlib.Path(filename) filename = pathlib.Path(filename)

Some files were not shown because too many files have changed in this diff Show More