forked from M-Labs/nac3
Compare commits
1 Commits
122983f11c
...
8dda6d5fe3
Author | SHA1 | Date | |
---|---|---|---|
8dda6d5fe3 |
@ -1,32 +0,0 @@
|
||||
BasedOnStyle: LLVM
|
||||
|
||||
Language: Cpp
|
||||
Standard: Cpp11
|
||||
|
||||
AccessModifierOffset: -1
|
||||
AlignEscapedNewlines: Left
|
||||
AlwaysBreakAfterReturnType: None
|
||||
AlwaysBreakTemplateDeclarations: Yes
|
||||
AllowAllParametersOfDeclarationOnNextLine: false
|
||||
AllowShortFunctionsOnASingleLine: Inline
|
||||
BinPackParameters: false
|
||||
BreakBeforeBinaryOperators: NonAssignment
|
||||
BreakBeforeTernaryOperators: true
|
||||
BreakConstructorInitializers: AfterColon
|
||||
BreakInheritanceList: AfterColon
|
||||
ColumnLimit: 120
|
||||
ConstructorInitializerAllOnOneLineOrOnePerLine: true
|
||||
ContinuationIndentWidth: 4
|
||||
DerivePointerAlignment: false
|
||||
IndentCaseLabels: true
|
||||
IndentPPDirectives: None
|
||||
IndentWidth: 4
|
||||
MaxEmptyLinesToKeep: 1
|
||||
PointerAlignment: Left
|
||||
ReflowComments: true
|
||||
SortIncludes: false
|
||||
SortUsingDeclarations: true
|
||||
SpaceAfterTemplateKeyword: false
|
||||
SpacesBeforeTrailingComments: 2
|
||||
TabWidth: 4
|
||||
UseTab: Never
|
83
Cargo.lock
generated
83
Cargo.lock
generated
@ -117,12 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.1.15"
|
||||
version = "1.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6"
|
||||
dependencies = [
|
||||
"shlex",
|
||||
]
|
||||
checksum = "e9e8aabfac534be767c909e0690571677d49f41bd8465ae876fe043d52ba5292"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
@ -132,9 +129,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.16"
|
||||
version = "4.5.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed6719fffa43d0d87e5fd8caeab59be1554fb028cd30edc88fc4369b17971019"
|
||||
checksum = "11d8838454fda655dafd3accb2b6e2bea645b9e4078abe84a22ceb947235c5cc"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
"clap_derive",
|
||||
@ -161,7 +158,7 @@ dependencies = [
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.76",
|
||||
"syn 2.0.74",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -310,9 +307,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.1.1"
|
||||
version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
|
||||
checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
|
||||
|
||||
[[package]]
|
||||
name = "fixedbitset"
|
||||
@ -388,9 +385,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.4.0"
|
||||
version = "2.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93ead53efc7ea8ed3cfb0c79fc8023fbb782a5432b52830b6518941cebe6505c"
|
||||
checksum = "de3fc2e30ba82dd1b3911c8de1ffc143c74a914a14e99514d7637e3099df5ea0"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown 0.14.5",
|
||||
@ -424,7 +421,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.76",
|
||||
"syn 2.0.74",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -510,9 +507,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.158"
|
||||
version = "0.2.155"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
|
||||
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
|
||||
|
||||
[[package]]
|
||||
name = "libloading"
|
||||
@ -619,7 +616,7 @@ name = "nac3core"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"crossbeam",
|
||||
"indexmap 2.4.0",
|
||||
"indexmap 2.3.0",
|
||||
"indoc",
|
||||
"inkwell",
|
||||
"insta",
|
||||
@ -709,7 +706,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db"
|
||||
dependencies = [
|
||||
"fixedbitset",
|
||||
"indexmap 2.4.0",
|
||||
"indexmap 2.3.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -752,7 +749,7 @@ dependencies = [
|
||||
"phf_shared 0.11.2",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.76",
|
||||
"syn 2.0.74",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -856,7 +853,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-macros-backend",
|
||||
"quote",
|
||||
"syn 2.0.76",
|
||||
"syn 2.0.74",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -869,14 +866,14 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-build-config",
|
||||
"quote",
|
||||
"syn 2.0.76",
|
||||
"syn 2.0.74",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.37"
|
||||
version = "1.0.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
|
||||
checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
@ -942,9 +939,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "redox_users"
|
||||
version = "0.4.6"
|
||||
version = "0.4.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
|
||||
checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"libredox",
|
||||
@ -989,9 +986,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.35"
|
||||
version = "0.38.34"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a85d50532239da68e9addb745ba38ff4612a242c1c7ceea689c4bc7c2f43c36f"
|
||||
checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"errno",
|
||||
@ -1035,29 +1032,29 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.209"
|
||||
version = "1.0.206"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09"
|
||||
checksum = "5b3e4cd94123dd520a128bcd11e34d9e9e423e7e3e50425cb1b4b1e3549d0284"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.209"
|
||||
version = "1.0.206"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170"
|
||||
checksum = "fabfb6138d2383ea8208cf98ccf69cdfb1aff4088460681d84189aa259762f97"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.76",
|
||||
"syn 2.0.74",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.127"
|
||||
version = "1.0.124"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad"
|
||||
checksum = "66ad62847a56b3dba58cc891acd13884b9c61138d330c0d7b6181713d4fce38d"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
@ -1077,12 +1074,6 @@ dependencies = [
|
||||
"yaml-rust",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shlex"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||
|
||||
[[package]]
|
||||
name = "similar"
|
||||
version = "2.6.0"
|
||||
@ -1147,7 +1138,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
"syn 2.0.76",
|
||||
"syn 2.0.74",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1163,9 +1154,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.76"
|
||||
version = "2.0.74"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525"
|
||||
checksum = "1fceb41e3d546d0bd83421d3409b1460cc7444cd389341a4c880fe7a042cb3d7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@ -1232,7 +1223,7 @@ checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.76",
|
||||
"syn 2.0.74",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1310,9 +1301,9 @@ checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.5"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a"
|
||||
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
|
||||
|
||||
[[package]]
|
||||
name = "unicode_names2"
|
||||
@ -1510,5 +1501,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.76",
|
||||
"syn 2.0.74",
|
||||
]
|
||||
|
6
flake.lock
generated
6
flake.lock
generated
@ -2,11 +2,11 @@
|
||||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1724819573,
|
||||
"narHash": "sha256-GnR7/ibgIH1vhoy8cYdmXE6iyZqKqFxQSVkFgosBh6w=",
|
||||
"lastModified": 1721924956,
|
||||
"narHash": "sha256-Sb1jlyRO+N8jBXEX9Pg9Z1Qb8Bw9QyOgLDNMEpmjZ2M=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "71e91c409d1e654808b2621f28a327acfdad8dc2",
|
||||
"rev": "5ad6a14c6bf098e98800b091668718c336effc95",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
@ -180,7 +180,9 @@
|
||||
clippy
|
||||
pre-commit
|
||||
rustfmt
|
||||
rust-analyzer
|
||||
];
|
||||
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
||||
shellHook =
|
||||
''
|
||||
export DEMO_LINALG_STUB=${packages.x86_64-linux.demo-linalg-stub}/lib/liblinalg.a
|
||||
|
@ -1,12 +1,9 @@
|
||||
use nac3core::{
|
||||
codegen::{
|
||||
classes::{
|
||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
|
||||
NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor,
|
||||
},
|
||||
classes::{ListValue, NDArrayValue, RangeValue, UntypedArrayLikeAccessor},
|
||||
expr::{destructure_range, gen_call},
|
||||
irrt::call_ndarray_calc_size,
|
||||
llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave},
|
||||
llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave},
|
||||
stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
@ -20,9 +17,9 @@ use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
module::Linkage,
|
||||
types::{BasicType, IntType},
|
||||
values::{BasicValueEnum, PointerValue, StructValue},
|
||||
AddressSpace, IntPredicate, OptimizationLevel,
|
||||
types::IntType,
|
||||
values::{BasicValueEnum, StructValue},
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
|
||||
use pyo3::{
|
||||
@ -32,7 +29,6 @@ use pyo3::{
|
||||
|
||||
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
||||
|
||||
use inkwell::values::IntValue;
|
||||
use itertools::Itertools;
|
||||
use std::{
|
||||
collections::{hash_map::DefaultHasher, HashMap},
|
||||
@ -131,7 +127,7 @@ impl<'a> ArtiqCodeGenerator<'a> {
|
||||
/// (possibly indirect) `parallel` block.
|
||||
///
|
||||
/// * `store_name` - The LLVM value name for the pointer to `end`. `.addr` will be appended to
|
||||
/// the end of the provided value name.
|
||||
/// the end of the provided value name.
|
||||
fn timeline_update_end_max(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
@ -426,10 +422,7 @@ fn gen_rpc_tag(
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
assert!(
|
||||
(0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims),
|
||||
"Only NDArrays of sizes between 0 and 255 can be RPCed"
|
||||
);
|
||||
assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims));
|
||||
|
||||
buffer.push(b'a');
|
||||
buffer.push((ndarray_ndims & 0xFF) as u8);
|
||||
@ -441,383 +434,6 @@ fn gen_rpc_tag(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Formats an RPC argument to conform to the expected format required by `send_value`.
|
||||
///
|
||||
/// See `artiq/firmware/libproto_artiq/rpc_proto.rs` for the expected format.
|
||||
fn format_rpc_arg<'ctx>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
(arg, arg_ty, arg_idx): (BasicValueEnum<'ctx>, Type, usize),
|
||||
) -> PointerValue<'ctx> {
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
|
||||
let arg_slot = match &*ctx.unifier.get_ty_immutable(arg_ty) {
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
// NAC3: NDArray = { usize, usize*, T* }
|
||||
// libproto_artiq: NDArray = [data[..], dim_sz[..]]
|
||||
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
||||
let llvm_arg_ty =
|
||||
NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty));
|
||||
let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
|
||||
|
||||
let llvm_usize_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "")
|
||||
.unwrap();
|
||||
let llvm_pdata_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(
|
||||
llvm_arg_ty.element_type().ptr_type(AddressSpace::default()).size_of(),
|
||||
llvm_usize,
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let dims_buf_sz =
|
||||
ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
||||
|
||||
let buffer_size =
|
||||
ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
|
||||
|
||||
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
|
||||
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
buffer.base_ptr(ctx, generator),
|
||||
llvm_arg.ptr_to_data(ctx),
|
||||
llvm_pdata_sizeof,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
|
||||
let pbuffer_dims_begin =
|
||||
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
pbuffer_dims_begin,
|
||||
llvm_arg.dim_sizes().base_ptr(ctx, generator),
|
||||
dims_buf_sz,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
|
||||
buffer.base_ptr(ctx, generator)
|
||||
}
|
||||
|
||||
_ => {
|
||||
let arg_slot = generator
|
||||
.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{arg_idx}")))
|
||||
.unwrap();
|
||||
ctx.builder.build_store(arg_slot, arg).unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_bitcast(arg_slot, llvm_pi8, "rpc.arg")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
debug_assert_eq!(arg_slot.get_type(), llvm_pi8);
|
||||
|
||||
arg_slot
|
||||
}
|
||||
|
||||
/// Formats an RPC return value to conform to the expected format required by NAC3.
|
||||
fn format_rpc_ret<'ctx>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ret_ty: Type,
|
||||
) -> Option<BasicValueEnum<'ctx>> {
|
||||
// -- receive value:
|
||||
// T result = {
|
||||
// void *ret_ptr = alloca(sizeof(T));
|
||||
// void *ptr = ret_ptr;
|
||||
// loop: int size = rpc_recv(ptr);
|
||||
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
|
||||
// if(size) { ptr = alloca(size); goto loop; }
|
||||
// else *(T*)ret_ptr
|
||||
// }
|
||||
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
|
||||
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
|
||||
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
|
||||
});
|
||||
|
||||
if ctx.unifier.unioned(ret_ty, ctx.primitives.none) {
|
||||
ctx.build_call_or_invoke(rpc_recv, &[llvm_pi8.const_null().into()], "rpc_recv");
|
||||
return None;
|
||||
}
|
||||
|
||||
let prehead_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let current_function = prehead_bb.get_parent().unwrap();
|
||||
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
|
||||
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
|
||||
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
|
||||
|
||||
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty);
|
||||
|
||||
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
// Round `val` up to its modulo `power_of_two`
|
||||
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
val: IntValue<'ctx>,
|
||||
power_of_two: IntValue<'ctx>| {
|
||||
debug_assert_eq!(
|
||||
val.get_type().get_bit_width(),
|
||||
power_of_two.get_type().get_bit_width()
|
||||
);
|
||||
|
||||
let llvm_val_t = val.get_type();
|
||||
|
||||
let max_rem = ctx
|
||||
.builder
|
||||
.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "")
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_and(
|
||||
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
|
||||
ctx.builder.build_not(max_rem, "").unwrap(),
|
||||
"",
|
||||
)
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
// Setup types
|
||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
|
||||
|
||||
// Allocate the resulting ndarray
|
||||
// A condition after format_rpc_ret ensures this will not be popped this off.
|
||||
let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result"));
|
||||
|
||||
// Setup ndims
|
||||
let ndims =
|
||||
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) {
|
||||
assert_eq!(values.len(), 1);
|
||||
|
||||
u64::try_from(values[0].clone()).unwrap()
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
// Set `ndarray.ndims`
|
||||
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
|
||||
// Allocate `ndarray.shape` [size_t; ndims]
|
||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx));
|
||||
|
||||
/*
|
||||
ndarray now:
|
||||
- .ndims: initialized
|
||||
- .shape: allocated but uninitialized .shape
|
||||
- .data: uninitialized
|
||||
*/
|
||||
|
||||
let llvm_usize_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
|
||||
.unwrap();
|
||||
let llvm_pdata_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(
|
||||
llvm_ret_ty.element_type().size_of().unwrap(),
|
||||
llvm_usize,
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
let llvm_elem_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
|
||||
.unwrap();
|
||||
|
||||
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
|
||||
// (4 + 4 * ndims) bytes with 8-byte alignment
|
||||
let sizeof_dims =
|
||||
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
||||
let unaligned_buffer_size =
|
||||
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
|
||||
let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false));
|
||||
|
||||
let stackptr = call_stacksave(ctx, None);
|
||||
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment
|
||||
let buffer = ctx
|
||||
.builder
|
||||
.build_array_alloca(
|
||||
llvm_i8_8,
|
||||
ctx.builder
|
||||
.build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "")
|
||||
.unwrap(),
|
||||
"rpc.buffer",
|
||||
)
|
||||
.unwrap();
|
||||
let buffer = ctx
|
||||
.builder
|
||||
.build_bitcast(buffer, llvm_pi8, "")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap();
|
||||
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
|
||||
|
||||
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
|
||||
//
|
||||
// The returned value is the number of bytes for `ndarray.data`.
|
||||
let ndarray_nbytes = ctx
|
||||
.build_call_or_invoke(
|
||||
rpc_recv,
|
||||
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims].
|
||||
"rpc.size.next",
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
|
||||
// debug_assert(ndarray_nbytes > 0)
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::UGT,
|
||||
ndarray_nbytes,
|
||||
ndarray_nbytes.get_type().const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap(),
|
||||
"0:AssertionError",
|
||||
"Unexpected RPC termination for ndarray - Expected data buffer next",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
// Copy shape from the buffer to `ndarray.shape`.
|
||||
let pbuffer_dims =
|
||||
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
ndarray.dim_sizes().base_ptr(ctx, generator),
|
||||
pbuffer_dims,
|
||||
sizeof_dims,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
// Restore stack from before allocation of buffer
|
||||
call_stackrestore(ctx, stackptr);
|
||||
|
||||
// Allocate `ndarray.data`.
|
||||
// `ndarray.shape` must be initialized beforehand in this implementation
|
||||
// (for ndarray.create_data() to know how many elements to allocate)
|
||||
let num_elements =
|
||||
call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None));
|
||||
|
||||
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let sizeof_data =
|
||||
ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::UGE,
|
||||
sizeof_data,
|
||||
ndarray_nbytes,
|
||||
"",
|
||||
).unwrap(),
|
||||
"0:AssertionError",
|
||||
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
|
||||
[Some(sizeof_data), Some(ndarray_nbytes), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
ndarray.create_data(ctx, llvm_elem_ty, num_elements);
|
||||
|
||||
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
|
||||
let ndarray_data_i8 =
|
||||
ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
|
||||
|
||||
// NOTE: Currently on `prehead_bb`
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
|
||||
// Inserting into `head_bb`. Do `rpc_recv` for `data` recursively.
|
||||
ctx.builder.position_at_end(head_bb);
|
||||
|
||||
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]);
|
||||
|
||||
let alloc_size = ctx
|
||||
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
|
||||
let is_done = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
|
||||
.unwrap();
|
||||
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(alloc_bb);
|
||||
// Align the allocation to sizeof(T)
|
||||
let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof);
|
||||
let alloc_ptr = ctx
|
||||
.builder
|
||||
.build_array_alloca(
|
||||
llvm_elem_ty,
|
||||
ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(),
|
||||
"rpc.alloc",
|
||||
)
|
||||
.unwrap();
|
||||
let alloc_ptr =
|
||||
ctx.builder.build_pointer_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(tail_bb);
|
||||
ndarray.as_base_value().into()
|
||||
}
|
||||
|
||||
_ => {
|
||||
let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap();
|
||||
let slotgen = ctx.builder.build_bitcast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
ctx.builder.position_at_end(head_bb);
|
||||
|
||||
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&slotgen, prehead_bb)]);
|
||||
let alloc_size = ctx
|
||||
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let is_done = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
|
||||
.unwrap();
|
||||
|
||||
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
||||
ctx.builder.position_at_end(alloc_bb);
|
||||
|
||||
let alloc_ptr =
|
||||
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
|
||||
let alloc_ptr =
|
||||
ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(tail_bb);
|
||||
ctx.builder.build_load(slot, "rpc.result").unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
Some(result)
|
||||
}
|
||||
|
||||
fn rpc_codegen_callback_fn<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
@ -825,10 +441,10 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||
let ptr_type = ctx.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||
let size_type = generator.get_size_type(ctx.ctx);
|
||||
let int8 = ctx.ctx.i8_type();
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let size_type = generator.get_size_type(ctx.ctx);
|
||||
let ptr_type = int8.ptr_type(AddressSpace::default());
|
||||
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
||||
|
||||
let service_id = int32.const_int(fun.1 .0 as u64, false);
|
||||
@ -901,25 +517,22 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
.0
|
||||
.args
|
||||
.iter()
|
||||
.map(|arg| {
|
||||
mapping
|
||||
.remove(&arg.name)
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, arg.ty)
|
||||
.map(|llvm_val| (llvm_val, arg.ty))
|
||||
})
|
||||
.collect::<Result<Vec<(_, _)>, _>>()?;
|
||||
.map(|arg| mapping.remove(&arg.name).unwrap().to_basic_value_enum(ctx, generator, arg.ty))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
if let Some(obj) = obj {
|
||||
if let ValueEnum::Static(obj_val) = obj.1 {
|
||||
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
|
||||
if let ValueEnum::Static(obj) = obj.1 {
|
||||
real_params.insert(0, obj.get_const_obj(ctx, generator));
|
||||
} else {
|
||||
// should be an error here...
|
||||
panic!("only host object is allowed");
|
||||
}
|
||||
}
|
||||
|
||||
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
|
||||
let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i));
|
||||
for (i, arg) in real_params.iter().enumerate() {
|
||||
let arg_slot =
|
||||
generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
|
||||
ctx.builder.build_store(arg_slot, *arg).unwrap();
|
||||
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap();
|
||||
let arg_ptr = unsafe {
|
||||
ctx.builder.build_gep(
|
||||
args_ptr,
|
||||
@ -953,14 +566,63 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
// reclaim stack space used by arguments
|
||||
call_stackrestore(ctx, stackptr);
|
||||
|
||||
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
||||
// -- receive value:
|
||||
// T result = {
|
||||
// void *ret_ptr = alloca(sizeof(T));
|
||||
// void *ptr = ret_ptr;
|
||||
// loop: int size = rpc_recv(ptr);
|
||||
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
|
||||
// if(size) { ptr = alloca(size); goto loop; }
|
||||
// else *(T*)ret_ptr
|
||||
// }
|
||||
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
|
||||
ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None)
|
||||
});
|
||||
|
||||
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
|
||||
// An RPC returning an NDArray would not touch here.
|
||||
call_stackrestore(ctx, stackptr);
|
||||
if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
|
||||
ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
let prehead_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let current_function = prehead_bb.get_parent().unwrap();
|
||||
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
|
||||
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
|
||||
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
|
||||
|
||||
let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret);
|
||||
let need_load = !ret_ty.is_pointer_type();
|
||||
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot").unwrap();
|
||||
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr").unwrap();
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
ctx.builder.position_at_end(head_bb);
|
||||
|
||||
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&slotgen, prehead_bb)]);
|
||||
let alloc_size = ctx
|
||||
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let is_done = ctx
|
||||
.builder
|
||||
.build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done")
|
||||
.unwrap();
|
||||
|
||||
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
||||
ctx.builder.position_at_end(alloc_bb);
|
||||
|
||||
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc").unwrap();
|
||||
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(tail_bb);
|
||||
|
||||
let result = ctx.builder.build_load(slot, "rpc.result").unwrap();
|
||||
if need_load {
|
||||
call_stackrestore(ctx, stackptr);
|
||||
}
|
||||
Ok(Some(result))
|
||||
}
|
||||
|
||||
pub fn attributes_writeback(
|
||||
|
@ -448,6 +448,7 @@ impl Nac3 {
|
||||
pyid_to_type: pyid_to_type.clone(),
|
||||
primitive_ids: self.primitive_ids.clone(),
|
||||
global_value_ids: global_value_ids.clone(),
|
||||
class_names: Mutex::default(),
|
||||
name_to_pyid: name_to_pyid.clone(),
|
||||
module: module.clone(),
|
||||
id_to_pyval: RwLock::default(),
|
||||
@ -539,6 +540,7 @@ impl Nac3 {
|
||||
pyid_to_type: pyid_to_type.clone(),
|
||||
primitive_ids: self.primitive_ids.clone(),
|
||||
global_value_ids: global_value_ids.clone(),
|
||||
class_names: Mutex::default(),
|
||||
id_to_pyval: RwLock::default(),
|
||||
id_to_primitive: RwLock::default(),
|
||||
field_to_val: RwLock::default(),
|
||||
@ -555,10 +557,6 @@ impl Nac3 {
|
||||
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
|
||||
.unwrap();
|
||||
|
||||
// Process IRRT
|
||||
let context = inkwell::context::Context::create();
|
||||
let irrt = load_irrt(&context, resolver.as_ref());
|
||||
|
||||
let fun_signature =
|
||||
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
|
||||
let mut store = ConcreteTypeStore::new();
|
||||
@ -729,7 +727,7 @@ impl Nac3 {
|
||||
membuffer.lock().push(buffer);
|
||||
});
|
||||
|
||||
// Link all modules into `main`.
|
||||
let context = inkwell::context::Context::create();
|
||||
let buffers = membuffers.lock();
|
||||
let main = context
|
||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
|
||||
@ -758,7 +756,8 @@ impl Nac3 {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||
main.link_in_module(load_irrt(&context))
|
||||
.map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||
|
||||
let mut function_iter = main.get_first_function();
|
||||
while let Some(func) = function_iter {
|
||||
|
@ -1,6 +1,4 @@
|
||||
use crate::PrimitivePythonId;
|
||||
use inkwell::{
|
||||
module::Linkage,
|
||||
types::{BasicType, BasicTypeEnum},
|
||||
values::BasicValueEnum,
|
||||
AddressSpace,
|
||||
@ -23,7 +21,7 @@ use nac3core::{
|
||||
},
|
||||
};
|
||||
use nac3parser::ast::{self, StrRef};
|
||||
use parking_lot::RwLock;
|
||||
use parking_lot::{Mutex, RwLock};
|
||||
use pyo3::{
|
||||
types::{PyDict, PyTuple},
|
||||
PyAny, PyObject, PyResult, Python,
|
||||
@ -36,6 +34,8 @@ use std::{
|
||||
},
|
||||
};
|
||||
|
||||
use crate::PrimitivePythonId;
|
||||
|
||||
pub enum PrimitiveValue {
|
||||
I32(i32),
|
||||
I64(i64),
|
||||
@ -79,6 +79,7 @@ pub struct InnerResolver {
|
||||
pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>,
|
||||
pub field_to_val: RwLock<HashMap<ResolverField, Option<PyFieldHandle>>>,
|
||||
pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>,
|
||||
pub class_names: Mutex<HashMap<StrRef, Type>>,
|
||||
pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
|
||||
pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
|
||||
pub primitive_ids: PrimitivePythonId,
|
||||
@ -132,8 +133,6 @@ impl StaticValue for PythonValue {
|
||||
format!("{}_const", self.id).as_str(),
|
||||
);
|
||||
global.set_constant(true);
|
||||
// Set linkage of global to private to avoid name collisions
|
||||
global.set_linkage(Linkage::Private);
|
||||
global.set_initializer(&ctx.ctx.const_struct(
|
||||
&[ctx.ctx.i32_type().const_int(u64::from(id), false).into()],
|
||||
false,
|
||||
|
@ -14,8 +14,8 @@ indexmap = "2.2"
|
||||
parking_lot = "0.12"
|
||||
rayon = "1.8"
|
||||
nac3parser = { path = "../nac3parser" }
|
||||
strum = "0.26"
|
||||
strum_macros = "0.26"
|
||||
strum = "0.26.2"
|
||||
strum_macros = "0.26.4"
|
||||
|
||||
[dependencies.inkwell]
|
||||
version = "0.4"
|
||||
|
@ -8,50 +8,37 @@ use std::{
|
||||
};
|
||||
|
||||
fn main() {
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
let out_dir = Path::new(&out_dir);
|
||||
let irrt_dir = Path::new("irrt");
|
||||
|
||||
let irrt_cpp_path = irrt_dir.join("irrt.cpp");
|
||||
const FILE: &str = "src/codegen/irrt/irrt.cpp";
|
||||
|
||||
/*
|
||||
* HACK: Sadly, clang doesn't let us emit generic LLVM bitcode.
|
||||
* Compiling for WASM32 and filtering the output with regex is the closest we can get.
|
||||
*/
|
||||
let mut flags: Vec<&str> = vec![
|
||||
let flags: &[&str] = &[
|
||||
"--target=wasm32",
|
||||
FILE,
|
||||
"-x",
|
||||
"c++",
|
||||
"-std=c++20",
|
||||
"-fno-discard-value-names",
|
||||
"-fno-exceptions",
|
||||
"-fno-rtti",
|
||||
match env::var("PROFILE").as_deref() {
|
||||
Ok("debug") => "-O0",
|
||||
Ok("release") => "-O3",
|
||||
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
|
||||
},
|
||||
"-emit-llvm",
|
||||
"-S",
|
||||
"-Wall",
|
||||
"-Wextra",
|
||||
"-o",
|
||||
"-",
|
||||
"-I",
|
||||
irrt_dir.to_str().unwrap(),
|
||||
irrt_cpp_path.to_str().unwrap(),
|
||||
];
|
||||
|
||||
match env::var("PROFILE").as_deref() {
|
||||
Ok("debug") => {
|
||||
flags.push("-O0");
|
||||
flags.push("-DIRRT_DEBUG_ASSERT");
|
||||
}
|
||||
Ok("release") => {
|
||||
flags.push("-O3");
|
||||
}
|
||||
flavor => panic!("Unknown or missing build flavor {flavor:?}"),
|
||||
}
|
||||
println!("cargo:rerun-if-changed={FILE}");
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
let out_path = Path::new(&out_dir);
|
||||
|
||||
// Tell Cargo to rerun if any file under `irrt_dir` (recursive) changes
|
||||
println!("cargo:rerun-if-changed={}", irrt_dir.to_str().unwrap());
|
||||
|
||||
// Compile IRRT and capture the LLVM IR output
|
||||
let output = Command::new("clang-irrt")
|
||||
.args(flags)
|
||||
.output()
|
||||
@ -65,17 +52,7 @@ fn main() {
|
||||
let output = std::str::from_utf8(&output.stdout).unwrap().replace("\r\n", "\n");
|
||||
let mut filtered_output = String::with_capacity(output.len());
|
||||
|
||||
// Filter out irrelevant IR
|
||||
//
|
||||
// Regex:
|
||||
// - `(?ms:^define.*?\}$)` captures LLVM `define` blocks
|
||||
// - `(?m:^declare.*?$)` captures LLVM `declare` lines
|
||||
// - `(?m:^%.+?=\s*type\s*\{.+?\}$)` captures LLVM `type` declarations
|
||||
// - `(?m:^@.+?=.+$)` captures global constants
|
||||
let regex_filter = Regex::new(
|
||||
r"(?ms:^define.*?\}$)|(?m:^declare.*?$)|(?m:^%.+?=\s*type\s*\{.+?\}$)|(?m:^@.+?=.+$)",
|
||||
)
|
||||
.unwrap();
|
||||
let regex_filter = Regex::new(r"(?ms:^define.*?\}$)|(?m:^declare.*?$)").unwrap();
|
||||
for f in regex_filter.captures_iter(&output) {
|
||||
assert_eq!(f.len(), 1);
|
||||
filtered_output.push_str(&f[0]);
|
||||
@ -86,22 +63,18 @@ fn main() {
|
||||
.unwrap()
|
||||
.replace_all(&filtered_output, "");
|
||||
|
||||
// For debugging
|
||||
// Doing `DEBUG_DUMP_IRRT=1 cargo build -p nac3core` dumps the LLVM IR generated
|
||||
const DEBUG_DUMP_IRRT: &str = "DEBUG_DUMP_IRRT";
|
||||
println!("cargo:rerun-if-env-changed={DEBUG_DUMP_IRRT}");
|
||||
if env::var(DEBUG_DUMP_IRRT).is_ok() {
|
||||
let mut file = File::create(out_dir.join("irrt.ll")).unwrap();
|
||||
println!("cargo:rerun-if-env-changed=DEBUG_DUMP_IRRT");
|
||||
if env::var("DEBUG_DUMP_IRRT").is_ok() {
|
||||
let mut file = File::create(out_path.join("irrt.ll")).unwrap();
|
||||
file.write_all(output.as_bytes()).unwrap();
|
||||
|
||||
let mut file = File::create(out_dir.join("irrt-filtered.ll")).unwrap();
|
||||
let mut file = File::create(out_path.join("irrt-filtered.ll")).unwrap();
|
||||
file.write_all(filtered_output.as_bytes()).unwrap();
|
||||
}
|
||||
|
||||
let mut llvm_as = Command::new("llvm-as-irrt")
|
||||
.stdin(Stdio::piped())
|
||||
.arg("-o")
|
||||
.arg(out_dir.join("irrt.bc"))
|
||||
.arg(out_path.join("irrt.bc"))
|
||||
.spawn()
|
||||
.unwrap();
|
||||
llvm_as.stdin.as_mut().unwrap().write_all(filtered_output.as_bytes()).unwrap();
|
||||
|
@ -1,6 +0,0 @@
|
||||
#include "irrt/exception.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/list.hpp"
|
||||
#include "irrt/math.hpp"
|
||||
#include "irrt/ndarray.hpp"
|
||||
#include "irrt/slice.hpp"
|
@ -1,9 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
template<typename SizeT>
|
||||
struct CSlice {
|
||||
uint8_t* base;
|
||||
SizeT len;
|
||||
};
|
@ -1,25 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
// Set in nac3core/build.rs
|
||||
#ifdef IRRT_DEBUG_ASSERT
|
||||
#define IRRT_DEBUG_ASSERT_BOOL true
|
||||
#else
|
||||
#define IRRT_DEBUG_ASSERT_BOOL false
|
||||
#endif
|
||||
|
||||
#define raise_debug_assert(SizeT, msg, param1, param2, param3) \
|
||||
raise_exception(SizeT, EXN_ASSERTION_ERROR, "IRRT debug assert failed: " msg, param1, param2, param3)
|
||||
|
||||
#define debug_assert_eq(SizeT, lhs, rhs) \
|
||||
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
|
||||
if ((lhs) != (rhs)) { \
|
||||
raise_debug_assert(SizeT, "LHS = {0}. RHS = {1}", lhs, rhs, NO_PARAM); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define debug_assert(SizeT, expr) \
|
||||
if constexpr (IRRT_DEBUG_ASSERT_BOOL) { \
|
||||
if (!(expr)) { \
|
||||
raise_debug_assert(SizeT, "Got false.", NO_PARAM, NO_PARAM, NO_PARAM); \
|
||||
} \
|
||||
}
|
@ -1,82 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/cslice.hpp"
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
/**
|
||||
* @brief The int type of ARTIQ exception IDs.
|
||||
*/
|
||||
typedef int32_t ExceptionId;
|
||||
|
||||
/*
|
||||
* Set of exceptions C++ IRRT can use.
|
||||
* Must be synchronized with `setup_irrt_exceptions` in `nac3core/src/codegen/irrt/mod.rs`.
|
||||
*/
|
||||
extern "C" {
|
||||
ExceptionId EXN_INDEX_ERROR;
|
||||
ExceptionId EXN_VALUE_ERROR;
|
||||
ExceptionId EXN_ASSERTION_ERROR;
|
||||
ExceptionId EXN_TYPE_ERROR;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Extern function to `__nac3_raise`
|
||||
*
|
||||
* The parameter `err` could be `Exception<int32_t>` or `Exception<int64_t>`. The caller
|
||||
* must make sure to pass `Exception`s with the correct `SizeT` depending on the `size_t` of the runtime.
|
||||
*/
|
||||
extern "C" void __nac3_raise(void* err);
|
||||
|
||||
namespace {
|
||||
/**
|
||||
* @brief NAC3's Exception struct
|
||||
*/
|
||||
template<typename SizeT>
|
||||
struct Exception {
|
||||
ExceptionId id;
|
||||
CSlice<SizeT> filename;
|
||||
int32_t line;
|
||||
int32_t column;
|
||||
CSlice<SizeT> function;
|
||||
CSlice<SizeT> msg;
|
||||
int64_t params[3];
|
||||
};
|
||||
|
||||
constexpr int64_t NO_PARAM = 0;
|
||||
|
||||
template<typename SizeT>
|
||||
void _raise_exception_helper(ExceptionId id,
|
||||
const char* filename,
|
||||
int32_t line,
|
||||
const char* function,
|
||||
const char* msg,
|
||||
int64_t param0,
|
||||
int64_t param1,
|
||||
int64_t param2) {
|
||||
Exception<SizeT> e = {
|
||||
.id = id,
|
||||
.filename = {.base = reinterpret_cast<const uint8_t*>(filename), .len = __builtin_strlen(filename)},
|
||||
.line = line,
|
||||
.column = 0,
|
||||
.function = {.base = reinterpret_cast<const uint8_t*>(function), .len = __builtin_strlen(function)},
|
||||
.msg = {.base = reinterpret_cast<const uint8_t*>(msg), .len = __builtin_strlen(msg)},
|
||||
};
|
||||
e.params[0] = param0;
|
||||
e.params[1] = param1;
|
||||
e.params[2] = param2;
|
||||
__nac3_raise(reinterpret_cast<void*>(&e));
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Raise an exception with location details (location in the IRRT source files).
|
||||
* @param SizeT The runtime `size_t` type.
|
||||
* @param id The ID of the exception to raise.
|
||||
* @param msg A global constant C-string of the error message.
|
||||
*
|
||||
* `param0` to `param2` are optional format arguments of `msg`. They should be set to
|
||||
* `NO_PARAM` to indicate they are unused.
|
||||
*/
|
||||
#define raise_exception(SizeT, id, msg, param0, param1, param2) \
|
||||
_raise_exception_helper<SizeT>(id, __FILE__, __LINE__, __FUNCTION__, msg, param0, param1, param2)
|
||||
} // namespace
|
@ -1,13 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
using int8_t = _BitInt(8);
|
||||
using uint8_t = unsigned _BitInt(8);
|
||||
using int32_t = _BitInt(32);
|
||||
using uint32_t = unsigned _BitInt(32);
|
||||
using int64_t = _BitInt(64);
|
||||
using uint64_t = unsigned _BitInt(64);
|
||||
|
||||
// NDArray indices are always `uint32_t`.
|
||||
using NDIndex = uint32_t;
|
||||
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
|
||||
using SliceIndex = int32_t;
|
@ -1,75 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
#include "irrt/math_util.hpp"
|
||||
|
||||
extern "C" {
|
||||
// Handle list assignment and dropping part of the list when
|
||||
// both dest_step and src_step are +1.
|
||||
// - All the index must *not* be out-of-bound or negative,
|
||||
// - The end index is *inclusive*,
|
||||
// - The length of src and dest slice size should already
|
||||
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
|
||||
SliceIndex __nac3_list_slice_assign_var_size(SliceIndex dest_start,
|
||||
SliceIndex dest_end,
|
||||
SliceIndex dest_step,
|
||||
uint8_t* dest_arr,
|
||||
SliceIndex dest_arr_len,
|
||||
SliceIndex src_start,
|
||||
SliceIndex src_end,
|
||||
SliceIndex src_step,
|
||||
uint8_t* src_arr,
|
||||
SliceIndex src_arr_len,
|
||||
const SliceIndex size) {
|
||||
/* if dest_arr_len == 0, do nothing since we do not support extending list */
|
||||
if (dest_arr_len == 0)
|
||||
return dest_arr_len;
|
||||
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
|
||||
if (src_step == dest_step && dest_step == 1) {
|
||||
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
|
||||
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
|
||||
if (src_len > 0) {
|
||||
__builtin_memmove(dest_arr + dest_start * size, src_arr + src_start * size, src_len * size);
|
||||
}
|
||||
if (dest_len > 0) {
|
||||
/* dropping */
|
||||
__builtin_memmove(dest_arr + (dest_start + src_len) * size, dest_arr + (dest_end + 1) * size,
|
||||
(dest_arr_len - dest_end - 1) * size);
|
||||
}
|
||||
/* shrink size */
|
||||
return dest_arr_len - (dest_len - src_len);
|
||||
}
|
||||
/* if two range overlaps, need alloca */
|
||||
uint8_t need_alloca = (dest_arr == src_arr)
|
||||
&& !(max(dest_start, dest_end) < min(src_start, src_end)
|
||||
|| max(src_start, src_end) < min(dest_start, dest_end));
|
||||
if (need_alloca) {
|
||||
uint8_t* tmp = reinterpret_cast<uint8_t*>(__builtin_alloca(src_arr_len * size));
|
||||
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
|
||||
src_arr = tmp;
|
||||
}
|
||||
SliceIndex src_ind = src_start;
|
||||
SliceIndex dest_ind = dest_start;
|
||||
for (; (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); src_ind += src_step, dest_ind += dest_step) {
|
||||
/* for constant optimization */
|
||||
if (size == 1) {
|
||||
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
|
||||
} else if (size == 4) {
|
||||
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
|
||||
} else if (size == 8) {
|
||||
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
|
||||
} else {
|
||||
/* memcpy for var size, cannot overlap after previous alloca */
|
||||
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
|
||||
}
|
||||
}
|
||||
/* only dest_step == 1 can we shrink the dest list. */
|
||||
/* size should be ensured prior to calling this function */
|
||||
if (dest_step == 1 && dest_end >= dest_start) {
|
||||
__builtin_memmove(dest_arr + dest_ind * size, dest_arr + (dest_end + 1) * size,
|
||||
(dest_arr_len - dest_end - 1) * size);
|
||||
return dest_arr_len - (dest_end - dest_ind) - 1;
|
||||
}
|
||||
return dest_arr_len;
|
||||
}
|
||||
} // extern "C"
|
@ -1,93 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
namespace {
|
||||
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||
// need to make sure `exp >= 0` before calling this function
|
||||
template<typename T>
|
||||
T __nac3_int_exp_impl(T base, T exp) {
|
||||
T res = 1;
|
||||
/* repeated squaring method */
|
||||
do {
|
||||
if (exp & 1) {
|
||||
res *= base; /* for n odd */
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
} while (exp);
|
||||
return res;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
#define DEF_nac3_int_exp_(T) \
|
||||
T __nac3_int_exp_##T(T base, T exp) { \
|
||||
return __nac3_int_exp_impl(base, exp); \
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
// Putting semicolons here to make clang-format not reformat this into
|
||||
// a stair shape.
|
||||
DEF_nac3_int_exp_(int32_t);
|
||||
DEF_nac3_int_exp_(int64_t);
|
||||
DEF_nac3_int_exp_(uint32_t);
|
||||
DEF_nac3_int_exp_(uint64_t);
|
||||
|
||||
int32_t __nac3_isinf(double x) {
|
||||
return __builtin_isinf(x);
|
||||
}
|
||||
|
||||
int32_t __nac3_isnan(double x) {
|
||||
return __builtin_isnan(x);
|
||||
}
|
||||
|
||||
double tgamma(double arg);
|
||||
|
||||
double __nac3_gamma(double z) {
|
||||
// Handling for denormals
|
||||
// | x | Python gamma(x) | C tgamma(x) |
|
||||
// --- | ----------------- | --------------- | ----------- |
|
||||
// (1) | nan | nan | nan |
|
||||
// (2) | -inf | -inf | inf |
|
||||
// (3) | inf | inf | inf |
|
||||
// (4) | 0.0 | inf | inf |
|
||||
// (5) | {-1.0, -2.0, ...} | inf | nan |
|
||||
|
||||
// (1)-(3)
|
||||
if (__builtin_isinf(z) || __builtin_isnan(z)) {
|
||||
return z;
|
||||
}
|
||||
|
||||
double v = tgamma(z);
|
||||
|
||||
// (4)-(5)
|
||||
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
|
||||
}
|
||||
|
||||
double lgamma(double arg);
|
||||
|
||||
double __nac3_gammaln(double x) {
|
||||
// libm's handling of value overflows differs from scipy:
|
||||
// - scipy: gammaln(-inf) -> -inf
|
||||
// - libm : lgamma(-inf) -> inf
|
||||
|
||||
if (__builtin_isinf(x)) {
|
||||
return x;
|
||||
}
|
||||
|
||||
return lgamma(x);
|
||||
}
|
||||
|
||||
double j0(double x);
|
||||
|
||||
double __nac3_j0(double x) {
|
||||
// libm's handling of value overflows differs from scipy:
|
||||
// - scipy: j0(inf) -> nan
|
||||
// - libm : j0(inf) -> 0.0
|
||||
|
||||
if (__builtin_isinf(x)) {
|
||||
return __builtin_nan("");
|
||||
}
|
||||
|
||||
return j0(x);
|
||||
}
|
||||
}
|
@ -1,13 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
namespace {
|
||||
template<typename T>
|
||||
const T& max(const T& a, const T& b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
const T& min(const T& a, const T& b) {
|
||||
return a > b ? b : a;
|
||||
}
|
||||
} // namespace
|
@ -1,144 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
namespace {
|
||||
template<typename SizeT>
|
||||
SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) {
|
||||
__builtin_assume(end_idx <= list_len);
|
||||
|
||||
SizeT num_elems = 1;
|
||||
for (SizeT i = begin_idx; i < end_idx; ++i) {
|
||||
SizeT val = list_data[i];
|
||||
__builtin_assume(val > 0);
|
||||
num_elems *= val;
|
||||
}
|
||||
return num_elems;
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndex* idxs) {
|
||||
SizeT stride = 1;
|
||||
for (SizeT dim = 0; dim < num_dims; dim++) {
|
||||
SizeT i = num_dims - dim - 1;
|
||||
__builtin_assume(dims[i] > 0);
|
||||
idxs[i] = (index / stride) % dims[i];
|
||||
stride *= dims[i];
|
||||
}
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, SizeT num_dims, const NDIndex* indices, SizeT num_indices) {
|
||||
SizeT idx = 0;
|
||||
SizeT stride = 1;
|
||||
for (SizeT i = 0; i < num_dims; ++i) {
|
||||
SizeT ri = num_dims - i - 1;
|
||||
if (ri < num_indices) {
|
||||
idx += stride * indices[ri];
|
||||
}
|
||||
|
||||
__builtin_assume(dims[i] > 0);
|
||||
stride *= dims[ri];
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
|
||||
SizeT lhs_ndims,
|
||||
const SizeT* rhs_dims,
|
||||
SizeT rhs_ndims,
|
||||
SizeT* out_dims) {
|
||||
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
||||
|
||||
for (SizeT i = 0; i < max_ndims; ++i) {
|
||||
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
|
||||
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
|
||||
SizeT* out_dim = &out_dims[max_ndims - i - 1];
|
||||
|
||||
if (lhs_dim_sz == nullptr) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (rhs_dim_sz == nullptr) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == 1) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (*rhs_dim_sz == 1) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else {
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename SizeT>
|
||||
void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims,
|
||||
SizeT src_ndims,
|
||||
const NDIndex* in_idx,
|
||||
NDIndex* out_idx) {
|
||||
for (SizeT i = 0; i < src_ndims; ++i) {
|
||||
SizeT src_i = src_ndims - i - 1;
|
||||
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) {
|
||||
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||
}
|
||||
|
||||
uint64_t
|
||||
__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) {
|
||||
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndex* idxs) {
|
||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndex* idxs) {
|
||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||
}
|
||||
|
||||
uint32_t
|
||||
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndex* indices, uint32_t num_indices) {
|
||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||
}
|
||||
|
||||
uint64_t
|
||||
__nac3_ndarray_flatten_index64(const uint64_t* dims, uint64_t num_dims, const NDIndex* indices, uint64_t num_indices) {
|
||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
|
||||
uint32_t lhs_ndims,
|
||||
const uint32_t* rhs_dims,
|
||||
uint32_t rhs_ndims,
|
||||
uint32_t* out_dims) {
|
||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims,
|
||||
uint64_t lhs_ndims,
|
||||
const uint64_t* rhs_dims,
|
||||
uint64_t rhs_ndims,
|
||||
uint64_t* out_dims) {
|
||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims,
|
||||
uint32_t src_ndims,
|
||||
const NDIndex* in_idx,
|
||||
NDIndex* out_idx) {
|
||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims,
|
||||
uint64_t src_ndims,
|
||||
const NDIndex* in_idx,
|
||||
NDIndex* out_idx) {
|
||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||
}
|
||||
}
|
@ -1,28 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "irrt/int_types.hpp"
|
||||
|
||||
extern "C" {
|
||||
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
||||
if (i < 0) {
|
||||
i = len + i;
|
||||
}
|
||||
if (i < 0) {
|
||||
return 0;
|
||||
} else if (i > len) {
|
||||
return len;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
SliceIndex __nac3_range_slice_len(const SliceIndex start, const SliceIndex end, const SliceIndex step) {
|
||||
SliceIndex diff = end - start;
|
||||
if (diff > 0 && step > 0) {
|
||||
return ((diff - 1) / step) + 1;
|
||||
} else if (diff < 0 && step < 0) {
|
||||
return ((diff + 1) / step) + 1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
@ -9,7 +9,6 @@ use crate::codegen::classes::{
|
||||
};
|
||||
use crate::codegen::expr::destructure_range;
|
||||
use crate::codegen::irrt::calculate_len_for_slice_range;
|
||||
use crate::codegen::macros::codegen_unreachable;
|
||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||
@ -21,8 +20,7 @@ use crate::typecheck::typedef::{Type, TypeEnum};
|
||||
///
|
||||
/// The generated message will contain the function name and the name of the unsupported type.
|
||||
fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) -> ! {
|
||||
codegen_unreachable!(
|
||||
ctx,
|
||||
unreachable!(
|
||||
"{fn_name}() not supported for '{}'",
|
||||
tys.iter().map(|ty| format!("'{}'", ctx.unifier.stringify(*ty))).join(", "),
|
||||
)
|
||||
@ -84,7 +82,7 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
|
||||
}
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -786,7 +784,7 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
|
||||
@ -890,7 +888,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
match fn_name {
|
||||
"np_argmin" | "np_argmax" => llvm_int64.const_zero().into(),
|
||||
"np_max" | "np_min" => a,
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
BasicValueEnum::PointerValue(n)
|
||||
@ -945,7 +943,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
"np_argmax" | "np_max" => {
|
||||
call_max(ctx, (elem_ty, accumulator), (elem_ty, elem))
|
||||
}
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let updated_idx = match (accumulator, result) {
|
||||
@ -982,7 +980,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
match fn_name {
|
||||
"np_argmin" | "np_argmax" => ctx.builder.build_load(res_idx, "").unwrap(),
|
||||
"np_max" | "np_min" => ctx.builder.build_load(accumulator_addr, "").unwrap(),
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -1048,7 +1046,7 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
|
||||
@ -1078,9 +1076,9 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// * `(arg_ty, arg_val)`: The [`Type`] and llvm value of the input argument.
|
||||
/// * `fn_name`: The name of the function, only used when throwing an error with [`unsupported_type`]
|
||||
/// * `get_ret_elem_type`: A function that takes in the input scalar [`Type`], and returns the function's return scalar [`Type`].
|
||||
/// Return a constant [`Type`] here if the return type does not depend on the input type.
|
||||
/// Return a constant [`Type`] here if the return type does not depend on the input type.
|
||||
/// * `on_scalar`: The function that acts on the scalars of the input. Returns [`Option::None`]
|
||||
/// if the scalar type & value are faulty and should panic with [`unsupported_type`].
|
||||
/// if the scalar type & value are faulty and should panic with [`unsupported_type`].
|
||||
fn helper_call_numpy_unary_elementwise<'ctx, OnScalarFn, RetElemFn, G>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -1191,9 +1189,9 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// * `$name:ident`: The identifier of the rust function to be generated.
|
||||
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`]
|
||||
/// * `$get_ret_elem_type:expr`: To be passed to the `get_ret_elem_type` parameter of [`helper_call_numpy_unary_elementwise`].
|
||||
/// But there is no need to make it a reference.
|
||||
/// But there is no need to make it a reference.
|
||||
/// * `$on_scalar:expr`: To be passed to the `on_scalar` parameter of [`helper_call_numpy_unary_elementwise`].
|
||||
/// But there is no need to make it a reference.
|
||||
/// But there is no need to make it a reference.
|
||||
macro_rules! create_helper_call_numpy_unary_elementwise {
|
||||
($name:ident, $fn_name:literal, $get_ret_elem_type:expr, $on_scalar:expr) => {
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
@ -1220,7 +1218,7 @@ macro_rules! create_helper_call_numpy_unary_elementwise {
|
||||
/// * `$name:ident`: The identifier of the rust function to be generated.
|
||||
/// * `$fn_name:literal`: To be passed to the `fn_name` parameter of [`helper_call_numpy_unary_elementwise`].
|
||||
/// * `$on_scalar:expr`: The closure (see below for its type) that acts on float scalar values and returns
|
||||
/// the boolean results of LLVM type `i1`. The returned `i1` value will be converted into an `i8`.
|
||||
/// the boolean results of LLVM type `i1`. The returned `i1` value will be converted into an `i8`.
|
||||
///
|
||||
/// ```ignore
|
||||
/// // Type of `$on_scalar:expr`
|
||||
@ -1488,7 +1486,7 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
|
||||
@ -1555,7 +1553,7 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
|
||||
@ -1622,7 +1620,7 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
|
||||
@ -1689,7 +1687,7 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
|
||||
@ -1812,7 +1810,7 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
|
||||
@ -1879,7 +1877,7 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
||||
} else if is_ndarray2 {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty).0
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let x1_scalar_ty = if is_ndarray1 { dtype } else { x1_ty };
|
||||
|
@ -1404,7 +1404,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
|
||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||
/// on the field.
|
||||
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
|
||||
|
||||
|
@ -9,9 +9,8 @@ use crate::{
|
||||
irrt::*,
|
||||
llvm_intrinsics::{
|
||||
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
||||
call_int_umin, call_memcpy_generic,
|
||||
call_memcpy_generic,
|
||||
},
|
||||
macros::codegen_unreachable,
|
||||
need_sret, numpy,
|
||||
stmt::{
|
||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||
@ -41,7 +40,6 @@ use nac3parser::ast::{
|
||||
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
||||
Unaryop,
|
||||
};
|
||||
use std::cmp::min;
|
||||
use std::iter::{repeat, repeat_with};
|
||||
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
||||
|
||||
@ -113,7 +111,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
let obj_id = match &*self.unifier.get_ty(ty) {
|
||||
TypeEnum::TObj { obj_id, .. } => *obj_id,
|
||||
// we cannot have other types, virtual type should be handled by function calls
|
||||
_ => codegen_unreachable!(self),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let def = &self.top_level.definitions.read()[obj_id.0];
|
||||
let (index, value) = if let TopLevelDef::Class { fields, attributes, .. } = &*def.read() {
|
||||
@ -124,7 +122,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
(attribute_index.0, Some(attribute_index.1 .2.clone()))
|
||||
}
|
||||
} else {
|
||||
codegen_unreachable!(self)
|
||||
unreachable!()
|
||||
};
|
||||
(index, value)
|
||||
}
|
||||
@ -134,7 +132,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
TypeEnum::TObj { fields, .. } => {
|
||||
fields.iter().find_position(|x| *x.0 == attr).unwrap().0
|
||||
}
|
||||
_ => codegen_unreachable!(self),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -189,7 +187,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
}
|
||||
_ => codegen_unreachable!(self, "must be option type"),
|
||||
_ => unreachable!("must be option type"),
|
||||
};
|
||||
let val = self.gen_symbol_val(generator, v, ty);
|
||||
let ptr = generator
|
||||
@ -205,7 +203,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
}
|
||||
_ => codegen_unreachable!(self, "must be option type"),
|
||||
_ => unreachable!("must be option type"),
|
||||
};
|
||||
let actual_ptr_type =
|
||||
self.get_llvm_type(generator, ty).ptr_type(AddressSpace::default());
|
||||
@ -272,7 +270,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
{
|
||||
self.ctx.i64_type()
|
||||
} else {
|
||||
codegen_unreachable!(self)
|
||||
unreachable!()
|
||||
};
|
||||
Some(ty.const_int(*val as u64, false).into())
|
||||
}
|
||||
@ -286,7 +284,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
let (types, is_vararg_ctx) = if let TypeEnum::TTuple { ty, is_vararg_ctx } = &*ty {
|
||||
(ty.clone(), *is_vararg_ctx)
|
||||
} else {
|
||||
codegen_unreachable!(self)
|
||||
unreachable!()
|
||||
};
|
||||
let values = zip(types, v.iter())
|
||||
.map_while(|(ty, v)| self.gen_const(generator, v, ty))
|
||||
@ -331,7 +329,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
|
||||
None
|
||||
}
|
||||
_ => codegen_unreachable!(self),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -345,7 +343,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
signed: bool,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) else {
|
||||
codegen_unreachable!(self)
|
||||
unreachable!()
|
||||
};
|
||||
let float = self.ctx.f64_type();
|
||||
match (op, signed) {
|
||||
@ -420,7 +418,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
.build_right_shift(lhs, rhs, signed, "rshift")
|
||||
.map(Into::into)
|
||||
.unwrap(),
|
||||
_ => codegen_unreachable!(self),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -432,7 +430,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
}
|
||||
(Operator::Pow, s) => integer_power(generator, self, lhs, rhs, s).into(),
|
||||
// special implementation?
|
||||
(Operator::MatMult, _) => codegen_unreachable!(self),
|
||||
(Operator::MatMult, _) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -444,8 +442,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||
rhs: BasicValueEnum<'ctx>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else {
|
||||
codegen_unreachable!(
|
||||
self,
|
||||
unreachable!(
|
||||
"Expected (FloatValue, FloatValue), got ({}, {})",
|
||||
lhs.get_type(),
|
||||
rhs.get_type()
|
||||
@ -689,7 +686,7 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>(
|
||||
def: &TopLevelDef,
|
||||
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let TopLevelDef::Class { methods, .. } = def else { codegen_unreachable!(ctx) };
|
||||
let TopLevelDef::Class { methods, .. } = def else { unreachable!() };
|
||||
|
||||
// TODO: what about other fields that require alloca?
|
||||
let fun_id = methods.iter().find(|method| method.0 == "__init__".into()).map(|method| method.2);
|
||||
@ -721,7 +718,7 @@ pub fn gen_func_instance<'ctx>(
|
||||
key,
|
||||
) = fun
|
||||
else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
if let Some(sym) = instance_to_symbol.get(&key) {
|
||||
@ -753,7 +750,7 @@ pub fn gen_func_instance<'ctx>(
|
||||
.collect();
|
||||
|
||||
let mut signature = store.from_signature(&mut ctx.unifier, &ctx.primitives, sign, &mut cache);
|
||||
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { codegen_unreachable!(ctx) };
|
||||
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else { unreachable!() };
|
||||
|
||||
if let Some(obj) = &obj {
|
||||
let zelf = store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache);
|
||||
@ -1119,7 +1116,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
expr: &Expr<Option<Type>>,
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||
let ExprKind::ListComp { elt, generators } = &expr.node else { codegen_unreachable!(ctx) };
|
||||
let ExprKind::ListComp { elt, generators } = &expr.node else { unreachable!() };
|
||||
|
||||
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
|
||||
|
||||
@ -1378,13 +1375,13 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) {
|
||||
ctx.unifier.get_representative(*params.iter().next().unwrap().1)
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
let elem_ty2 =
|
||||
if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) {
|
||||
ctx.unifier.get_representative(*params.iter().next().unwrap().1)
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
debug_assert!(ctx.unifier.unioned(elem_ty1, elem_ty2));
|
||||
|
||||
@ -1457,7 +1454,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
(elem_ty, left_val, right_val)
|
||||
@ -1467,12 +1464,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
(elem_ty, right_val, left_val)
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
let list_val =
|
||||
ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None);
|
||||
@ -1639,7 +1636,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
} else {
|
||||
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
||||
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {
|
||||
codegen_unreachable!(ctx, "must be tobj")
|
||||
unreachable!("must be tobj")
|
||||
};
|
||||
let (op_name, id) = {
|
||||
let normal_method_name = Binop::normal(op.base).op_info().method_name;
|
||||
@ -1660,19 +1657,19 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
} else {
|
||||
let left_enum_ty = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
||||
let TypeEnum::TObj { fields, .. } = left_enum_ty.as_ref() else {
|
||||
codegen_unreachable!(ctx, "must be tobj")
|
||||
unreachable!("must be tobj")
|
||||
};
|
||||
|
||||
let fn_ty = fields.get(&op_name).unwrap().0;
|
||||
let fn_ty_enum = ctx.unifier.get_ty_immutable(fn_ty);
|
||||
let TypeEnum::TFunc(sig) = fn_ty_enum.as_ref() else { codegen_unreachable!(ctx) };
|
||||
let TypeEnum::TFunc(sig) = fn_ty_enum.as_ref() else { unreachable!() };
|
||||
|
||||
sig.clone()
|
||||
};
|
||||
let fun_id = {
|
||||
let defs = ctx.top_level.definitions.read();
|
||||
let obj_def = defs.get(id.0).unwrap().read();
|
||||
let TopLevelDef::Class { methods, .. } = &*obj_def else { codegen_unreachable!(ctx) };
|
||||
let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() };
|
||||
|
||||
methods.iter().find(|method| method.0 == op_name).unwrap().2
|
||||
};
|
||||
@ -1803,8 +1800,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
if op == ast::Unaryop::Invert {
|
||||
ast::Unaryop::Not
|
||||
} else {
|
||||
codegen_unreachable!(
|
||||
ctx,
|
||||
unreachable!(
|
||||
"ufunc {} not supported for ndarray[bool, N]",
|
||||
op.op_info().method_name,
|
||||
)
|
||||
@ -1871,8 +1867,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (Some(left_ty), lhs) = left else { codegen_unreachable!(ctx) };
|
||||
let (Some(right_ty), rhs) = comparators[0] else { codegen_unreachable!(ctx) };
|
||||
let (Some(left_ty), lhs) = left else { unreachable!() };
|
||||
let (Some(right_ty), rhs) = comparators[0] else { unreachable!() };
|
||||
let op = ops[0];
|
||||
|
||||
let is_ndarray1 =
|
||||
@ -1979,7 +1975,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
let op = match op {
|
||||
ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ,
|
||||
ast::Cmpop::NotEq => IntPredicate::NE,
|
||||
_ if left_ty == ctx.primitives.bool => codegen_unreachable!(ctx),
|
||||
_ if left_ty == ctx.primitives.bool => unreachable!(),
|
||||
ast::Cmpop::Lt => {
|
||||
if use_unsigned_ops {
|
||||
IntPredicate::ULT
|
||||
@ -2008,7 +2004,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
IntPredicate::SGE
|
||||
}
|
||||
}
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
ctx.builder.build_int_compare(op, lhs, rhs, "cmp").unwrap()
|
||||
@ -2025,118 +2021,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
|
||||
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
|
||||
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap()
|
||||
} else if left_ty == ctx.primitives.str {
|
||||
assert!(ctx.unifier.unioned(left_ty, right_ty));
|
||||
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let lhs = lhs.into_struct_value();
|
||||
let rhs = rhs.into_struct_value();
|
||||
|
||||
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
||||
ctx.builder.build_store(plhs, lhs).unwrap();
|
||||
let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
||||
ctx.builder.build_store(prhs, rhs).unwrap();
|
||||
|
||||
let lhs_len = ctx.build_in_bounds_gep_and_load(
|
||||
plhs,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
||||
None,
|
||||
).into_int_value();
|
||||
let rhs_len = ctx.build_in_bounds_gep_and_load(
|
||||
prhs,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
||||
None,
|
||||
).into_int_value();
|
||||
|
||||
let len = call_int_umin(ctx, lhs_len, rhs_len, None);
|
||||
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end");
|
||||
|
||||
ctx.builder.position_at_end(post_foreach_cmp);
|
||||
let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap();
|
||||
ctx.builder.position_at_end(current_bb);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(len, false),
|
||||
|generator, ctx, _, i| {
|
||||
let lhs_char = {
|
||||
let plhs_data = ctx.build_in_bounds_gep_and_load(
|
||||
plhs,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_pointer_value();
|
||||
|
||||
ctx.build_in_bounds_gep_and_load(
|
||||
plhs_data,
|
||||
&[i],
|
||||
None
|
||||
).into_int_value()
|
||||
};
|
||||
let rhs_char = {
|
||||
let prhs_data = ctx.build_in_bounds_gep_and_load(
|
||||
prhs,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_pointer_value();
|
||||
|
||||
ctx.build_in_bounds_gep_and_load(
|
||||
prhs_data,
|
||||
&[i],
|
||||
None
|
||||
).into_int_value()
|
||||
};
|
||||
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap())
|
||||
},
|
||||
|_, ctx| {
|
||||
let bb = ctx.builder.get_insert_block().unwrap();
|
||||
cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]);
|
||||
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, _| Ok(()),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let bb = ctx.builder.get_insert_block().unwrap();
|
||||
let is_len_eq = ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
lhs_len,
|
||||
rhs_len,
|
||||
"",
|
||||
).unwrap();
|
||||
cmp_phi.add_incoming(&[(&is_len_eq, bb)]);
|
||||
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(post_foreach_cmp);
|
||||
let cmp_phi = cmp_phi.as_basic_value().into_int_value();
|
||||
|
||||
// Invert the final value if __ne__
|
||||
if *op == Cmpop::NotEq {
|
||||
ctx.builder.build_not(cmp_phi, "").unwrap()
|
||||
} else {
|
||||
cmp_phi
|
||||
}
|
||||
} else if [left_ty, right_ty]
|
||||
.iter()
|
||||
.any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()))
|
||||
@ -2157,7 +2044,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
match (op, val) {
|
||||
(Cmpop::Eq, true) | (Cmpop::NotEq, false) => llvm_i1.const_all_ones(),
|
||||
(Cmpop::Eq, false) | (Cmpop::NotEq, true) => llvm_i1.const_zero(),
|
||||
(_, _) => codegen_unreachable!(ctx),
|
||||
(_, _) => unreachable!(),
|
||||
}
|
||||
};
|
||||
|
||||
@ -2170,14 +2057,14 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
let right_elem_ty = if let TypeEnum::TObj { params, .. } =
|
||||
&*ctx.unifier.get_ty_immutable(right_ty)
|
||||
{
|
||||
*params.iter().next().unwrap().1
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
if !ctx.unifier.unioned(left_elem_ty, right_elem_ty) {
|
||||
@ -2307,124 +2194,8 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
};
|
||||
|
||||
gen_list_cmpop(generator, ctx)?
|
||||
} else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TTuple { .. })) {
|
||||
let TypeEnum::TTuple { ty: left_tys, .. } = &*ctx.unifier.get_ty_immutable(left_ty) else {
|
||||
return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty)))
|
||||
};
|
||||
let TypeEnum::TTuple { ty: right_tys, .. } = &*ctx.unifier.get_ty_immutable(right_ty) else {
|
||||
return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty)))
|
||||
};
|
||||
|
||||
if ![Cmpop::Eq, Cmpop::NotEq].contains(op) {
|
||||
todo!("Only __eq__ and __ne__ is implemented for tuples")
|
||||
}
|
||||
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
|
||||
// Assume `true` by default
|
||||
let cmp_addr = generator.gen_var_alloc(ctx, llvm_i1.into(), None).unwrap();
|
||||
ctx.builder.build_store(cmp_addr, llvm_i1.const_all_ones()).unwrap();
|
||||
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end");
|
||||
|
||||
ctx.builder.position_at_end(post_foreach_cmp);
|
||||
let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap();
|
||||
ctx.builder.position_at_end(current_bb);
|
||||
|
||||
// Generate comparison between each element
|
||||
let min_len = min(left_tys.len(), right_tys.len());
|
||||
for i in 0..min_len {
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let bb = ctx.ctx.insert_basic_block_after(current_bb, &format!("foreach.cmp.tuple.{i}e"));
|
||||
ctx.builder.build_unconditional_branch(bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(bb);
|
||||
let left_ty = left_tys[i];
|
||||
let left_elem = {
|
||||
let plhs = generator.gen_var_alloc(ctx, lhs.get_type(), None).unwrap();
|
||||
ctx.builder.build_store(plhs, *lhs).unwrap();
|
||||
|
||||
ctx.build_in_bounds_gep_and_load(
|
||||
plhs,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(i as u64, false)],
|
||||
None,
|
||||
)
|
||||
};
|
||||
let right_ty = right_tys[i];
|
||||
let right_elem = {
|
||||
let prhs = generator.gen_var_alloc(ctx, rhs.get_type(), None).unwrap();
|
||||
ctx.builder.build_store(prhs, *rhs).unwrap();
|
||||
|
||||
ctx.build_in_bounds_gep_and_load(
|
||||
prhs,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(i as u64, false)],
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
// Defer the `not` operation until the end - a != b <=> !(a == b)
|
||||
let op = if *op == Cmpop::NotEq { Cmpop::Eq } else { *op };
|
||||
|
||||
let cmp = gen_cmpop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(Some(left_ty), left_elem),
|
||||
&[op],
|
||||
&[(Some(right_ty), right_elem)],
|
||||
)
|
||||
.transpose()
|
||||
.unwrap()
|
||||
.and_then(|v| {
|
||||
v.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
|
||||
})
|
||||
.map(BasicValueEnum::into_int_value)?;
|
||||
|
||||
Ok(ctx.builder.build_not(
|
||||
generator.bool_to_i1(ctx, cmp),
|
||||
"",
|
||||
).unwrap())
|
||||
},
|
||||
|_, ctx| {
|
||||
let bb = ctx.builder.get_insert_block().unwrap();
|
||||
cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]);
|
||||
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, _| Ok(()),
|
||||
)?;
|
||||
}
|
||||
|
||||
// Length of tuples is checked last as operators do not short-circuit by tuple
|
||||
// length in Python:
|
||||
//
|
||||
// >>> (1, 2) < ("a",)
|
||||
// TypeError: '<' not supported between instances of 'int' and 'str'
|
||||
let bb = ctx.builder.get_insert_block().unwrap();
|
||||
let is_len_eq = llvm_i1.const_int(
|
||||
u64::from(left_tys.len() == right_tys.len()),
|
||||
false,
|
||||
);
|
||||
cmp_phi.add_incoming(&[(&is_len_eq, bb)]);
|
||||
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(post_foreach_cmp);
|
||||
let cmp_phi = cmp_phi.as_basic_value().into_int_value();
|
||||
|
||||
// Invert the final value if __ne__
|
||||
if *op == Cmpop::NotEq {
|
||||
ctx.builder.build_not(cmp_phi, "").unwrap()
|
||||
} else {
|
||||
cmp_phi
|
||||
}
|
||||
} else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TVar { .. })) {
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
if ctx.registry.llvm_options.opt_level != OptimizationLevel::None {
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.ctx.bool_type().const_all_ones(),
|
||||
@ -2437,10 +2208,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
|
||||
ctx.ctx.bool_type().get_poison()
|
||||
} else {
|
||||
return Err(format!("'{}' not supported between instances of '{}' and '{}'",
|
||||
op.op_info().symbol,
|
||||
ctx.unifier.stringify(left_ty),
|
||||
ctx.unifier.stringify(right_ty)))
|
||||
unimplemented!()
|
||||
};
|
||||
|
||||
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
|
||||
@ -2517,7 +2285,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let ndims = values
|
||||
@ -2869,7 +2637,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
.const_null()
|
||||
.into()
|
||||
}
|
||||
_ => codegen_unreachable!(ctx, "must be option type"),
|
||||
_ => unreachable!("must be option type"),
|
||||
}
|
||||
}
|
||||
ExprKind::Name { id, .. } => match ctx.var_assignment.get(id) {
|
||||
@ -2879,7 +2647,29 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()),
|
||||
None => {
|
||||
let resolver = ctx.resolver.clone();
|
||||
resolver.get_symbol_value(*id, ctx).unwrap()
|
||||
if let Some(res) = resolver.get_symbol_value(*id, ctx) {
|
||||
res
|
||||
} else {
|
||||
// Allow "raise Exception" short form
|
||||
let def_id = resolver.get_identifier_def(*id).map_err(|e| {
|
||||
format!("{} (at {})", e.iter().next().unwrap(), expr.location)
|
||||
})?;
|
||||
let def = ctx.top_level.definitions.read();
|
||||
if let TopLevelDef::Class { constructor, .. } = *def[def_id.0].read() {
|
||||
let TypeEnum::TFunc(signature) =
|
||||
ctx.unifier.get_ty(constructor.unwrap()).as_ref().clone()
|
||||
else {
|
||||
return Err(format!(
|
||||
"Failed to resolve symbol {} (at {})",
|
||||
id, expr.location
|
||||
));
|
||||
};
|
||||
return Ok(generator
|
||||
.gen_call(ctx, None, (&signature, def_id), Vec::default())?
|
||||
.map(Into::into));
|
||||
}
|
||||
return Err(format!("Failed to resolve symbol {} (at {})", id, expr.location));
|
||||
}
|
||||
}
|
||||
},
|
||||
ExprKind::List { elts, .. } => {
|
||||
@ -2908,7 +2698,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
|
||||
*params.iter().next().unwrap().1
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
if let TypeEnum::TVar { .. } = &*ctx.unifier.get_ty_immutable(ty) {
|
||||
@ -3002,9 +2792,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
|
||||
return generator.gen_expr(ctx, &modified_expr);
|
||||
}
|
||||
None => {
|
||||
codegen_unreachable!(ctx, "Function Type should not have attributes")
|
||||
}
|
||||
None => unreachable!("Function Type should not have attributes"),
|
||||
}
|
||||
} else if let TypeEnum::TObj { obj_id, fields, params } = &*ctx.unifier.get_ty(c) {
|
||||
if fields.is_empty() && params.is_empty() {
|
||||
@ -3026,7 +2814,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
|
||||
return generator.gen_expr(ctx, &modified_expr);
|
||||
}
|
||||
None => codegen_unreachable!(ctx),
|
||||
None => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3128,7 +2916,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
(Some(a), None) => a.into(),
|
||||
(None, Some(b)) => b.into(),
|
||||
(None, None) => codegen_unreachable!(ctx),
|
||||
(None, None) => unreachable!(),
|
||||
}
|
||||
}
|
||||
ExprKind::BinOp { op, left, right } => {
|
||||
@ -3194,6 +2982,29 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
}
|
||||
ExprKind::Call { func, args, keywords } => {
|
||||
// Check if call is to a parent method
|
||||
let mut is_override = false;
|
||||
if let Some(arg) = args.last() {
|
||||
if let ExprKind::Name { id, .. } = arg.node {
|
||||
if id == "self".into() {
|
||||
is_override = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut args = args.clone();
|
||||
let (zelf, func_id) = if is_override {
|
||||
let zelf = args.pop();
|
||||
let ExprKind::Constant { value: ast::Constant::Int(func_id), .. } =
|
||||
args.pop().unwrap().node
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
(zelf, Some(func_id))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
let mut params = args
|
||||
.iter()
|
||||
.map(|arg| generator.gen_expr(ctx, arg))
|
||||
@ -3218,9 +3029,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
ctx.unifier.get_call_signature(*call).unwrap()
|
||||
} else {
|
||||
let ty = func.custom.unwrap();
|
||||
let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) else {
|
||||
codegen_unreachable!(ctx)
|
||||
};
|
||||
let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) else { unreachable!() };
|
||||
|
||||
sign.clone()
|
||||
};
|
||||
@ -3239,26 +3048,24 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) };
|
||||
|
||||
// Handle Class Method calls
|
||||
// The attribute will be `DefinitionId` of the method if the call is to one of the parent methods
|
||||
let func_id = attr.to_string().parse::<usize>();
|
||||
|
||||
let id = if let TypeEnum::TObj { obj_id, .. } =
|
||||
&*ctx.unifier.get_ty(value.custom.unwrap())
|
||||
{
|
||||
let class_ty = if is_override {
|
||||
zelf.unwrap().custom.unwrap()
|
||||
} else {
|
||||
value.custom.unwrap()
|
||||
};
|
||||
let id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(class_ty) {
|
||||
*obj_id
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
// Use the `DefinitionID` from attribute if it is available
|
||||
let fun_id = if let Ok(func_id) = func_id {
|
||||
DefinitionId(func_id)
|
||||
// Get function definition
|
||||
let fun_id = if is_override {
|
||||
DefinitionId(func_id.unwrap() as usize)
|
||||
} else {
|
||||
let defs = ctx.top_level.definitions.read();
|
||||
let obj_def = defs.get(id.0).unwrap().read();
|
||||
let TopLevelDef::Class { methods, .. } = &*obj_def else {
|
||||
codegen_unreachable!(ctx)
|
||||
};
|
||||
let TopLevelDef::Class { methods, .. } = &*obj_def else { unreachable!() };
|
||||
|
||||
methods.iter().find(|method| method.0 == *attr).unwrap().2
|
||||
};
|
||||
@ -3329,9 +3136,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
.unwrap(),
|
||||
));
|
||||
}
|
||||
ValueEnum::Dynamic(_) => {
|
||||
codegen_unreachable!(ctx, "option must be static or ptr")
|
||||
}
|
||||
ValueEnum::Dynamic(_) => unreachable!("option must be static or ptr"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -3480,10 +3285,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
if let ExprKind::Constant { value: Constant::Int(v), .. } = &slice.node {
|
||||
(*v).try_into().unwrap()
|
||||
} else {
|
||||
codegen_unreachable!(
|
||||
ctx,
|
||||
"tuple subscript must be const int after type check"
|
||||
);
|
||||
unreachable!("tuple subscript must be const int after type check");
|
||||
};
|
||||
match generator.gen_expr(ctx, value)? {
|
||||
Some(ValueEnum::Dynamic(v)) => {
|
||||
@ -3506,10 +3308,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
None => return Ok(None),
|
||||
}
|
||||
}
|
||||
_ => codegen_unreachable!(
|
||||
ctx,
|
||||
"should not be other subscriptable types after type check"
|
||||
),
|
||||
_ => unreachable!("should not be other subscriptable types after type check"),
|
||||
}
|
||||
}
|
||||
ExprKind::ListComp { .. } => {
|
||||
|
@ -13,11 +13,11 @@ use crate::codegen::CodeGenContext;
|
||||
/// * `$extern_fn:literal`: Name of underlying extern function
|
||||
///
|
||||
/// Optional Arguments:
|
||||
/// * `$(,$attributes:literal)*)`: Attributes linked with the extern function.
|
||||
/// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly".
|
||||
/// These will be used unless other attributes are specified
|
||||
/// * `$(,$attributes:literal)*)`: Attributes linked with the extern function
|
||||
/// The default attributes are "mustprogress", "nofree", "nounwind", "willreturn", and "writeonly"
|
||||
/// These will be used unless other attributes are specified
|
||||
/// * `$(,$args:ident)*`: Operands of the extern function
|
||||
/// The data type of these operands will be set to `FloatValue`
|
||||
/// The data type of these operands will be set to `FloatValue`
|
||||
///
|
||||
macro_rules! generate_extern_fn {
|
||||
("unary", $fn_name:ident, $extern_fn:literal) => {
|
||||
|
@ -57,7 +57,6 @@ pub trait CodeGenerator {
|
||||
/// - fun: Function signature, definition ID and the substitution key.
|
||||
/// - params: Function parameters. Note that this does not include the object even if the
|
||||
/// function is a class method.
|
||||
///
|
||||
/// Note that this function should check if the function is generated in another thread (due to
|
||||
/// possible race condition), see the default implementation for an example.
|
||||
fn gen_func_instance<'ctx>(
|
||||
|
414
nac3core/src/codegen/irrt/irrt.cpp
Normal file
414
nac3core/src/codegen/irrt/irrt.cpp
Normal file
@ -0,0 +1,414 @@
|
||||
using int8_t = _BitInt(8);
|
||||
using uint8_t = unsigned _BitInt(8);
|
||||
using int32_t = _BitInt(32);
|
||||
using uint32_t = unsigned _BitInt(32);
|
||||
using int64_t = _BitInt(64);
|
||||
using uint64_t = unsigned _BitInt(64);
|
||||
|
||||
// NDArray indices are always `uint32_t`.
|
||||
using NDIndex = uint32_t;
|
||||
// The type of an index or a value describing the length of a range/slice is always `int32_t`.
|
||||
using SliceIndex = int32_t;
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
const T& max(const T& a, const T& b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
const T& min(const T& a, const T& b) {
|
||||
return a > b ? b : a;
|
||||
}
|
||||
|
||||
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
|
||||
// need to make sure `exp >= 0` before calling this function
|
||||
template <typename T>
|
||||
T __nac3_int_exp_impl(T base, T exp) {
|
||||
T res = 1;
|
||||
/* repeated squaring method */
|
||||
do {
|
||||
if (exp & 1) {
|
||||
res *= base; /* for n odd */
|
||||
}
|
||||
exp >>= 1;
|
||||
base *= base;
|
||||
} while (exp);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename SizeT>
|
||||
SizeT __nac3_ndarray_calc_size_impl(
|
||||
const SizeT* list_data,
|
||||
SizeT list_len,
|
||||
SizeT begin_idx,
|
||||
SizeT end_idx
|
||||
) {
|
||||
__builtin_assume(end_idx <= list_len);
|
||||
|
||||
SizeT num_elems = 1;
|
||||
for (SizeT i = begin_idx; i < end_idx; ++i) {
|
||||
SizeT val = list_data[i];
|
||||
__builtin_assume(val > 0);
|
||||
num_elems *= val;
|
||||
}
|
||||
return num_elems;
|
||||
}
|
||||
|
||||
template <typename SizeT>
|
||||
void __nac3_ndarray_calc_nd_indices_impl(
|
||||
SizeT index,
|
||||
const SizeT* dims,
|
||||
SizeT num_dims,
|
||||
NDIndex* idxs
|
||||
) {
|
||||
SizeT stride = 1;
|
||||
for (SizeT dim = 0; dim < num_dims; dim++) {
|
||||
SizeT i = num_dims - dim - 1;
|
||||
__builtin_assume(dims[i] > 0);
|
||||
idxs[i] = (index / stride) % dims[i];
|
||||
stride *= dims[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SizeT>
|
||||
SizeT __nac3_ndarray_flatten_index_impl(
|
||||
const SizeT* dims,
|
||||
SizeT num_dims,
|
||||
const NDIndex* indices,
|
||||
SizeT num_indices
|
||||
) {
|
||||
SizeT idx = 0;
|
||||
SizeT stride = 1;
|
||||
for (SizeT i = 0; i < num_dims; ++i) {
|
||||
SizeT ri = num_dims - i - 1;
|
||||
if (ri < num_indices) {
|
||||
idx += stride * indices[ri];
|
||||
}
|
||||
|
||||
__builtin_assume(dims[i] > 0);
|
||||
stride *= dims[ri];
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
template <typename SizeT>
|
||||
void __nac3_ndarray_calc_broadcast_impl(
|
||||
const SizeT* lhs_dims,
|
||||
SizeT lhs_ndims,
|
||||
const SizeT* rhs_dims,
|
||||
SizeT rhs_ndims,
|
||||
SizeT* out_dims
|
||||
) {
|
||||
SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
|
||||
|
||||
for (SizeT i = 0; i < max_ndims; ++i) {
|
||||
const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr;
|
||||
const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr;
|
||||
SizeT* out_dim = &out_dims[max_ndims - i - 1];
|
||||
|
||||
if (lhs_dim_sz == nullptr) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (rhs_dim_sz == nullptr) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == 1) {
|
||||
*out_dim = *rhs_dim_sz;
|
||||
} else if (*rhs_dim_sz == 1) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else if (*lhs_dim_sz == *rhs_dim_sz) {
|
||||
*out_dim = *lhs_dim_sz;
|
||||
} else {
|
||||
__builtin_unreachable();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SizeT>
|
||||
void __nac3_ndarray_calc_broadcast_idx_impl(
|
||||
const SizeT* src_dims,
|
||||
SizeT src_ndims,
|
||||
const NDIndex* in_idx,
|
||||
NDIndex* out_idx
|
||||
) {
|
||||
for (SizeT i = 0; i < src_ndims; ++i) {
|
||||
SizeT src_i = src_ndims - i - 1;
|
||||
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
#define DEF_nac3_int_exp_(T) \
|
||||
T __nac3_int_exp_##T(T base, T exp) {\
|
||||
return __nac3_int_exp_impl(base, exp);\
|
||||
}
|
||||
|
||||
DEF_nac3_int_exp_(int32_t)
|
||||
DEF_nac3_int_exp_(int64_t)
|
||||
DEF_nac3_int_exp_(uint32_t)
|
||||
DEF_nac3_int_exp_(uint64_t)
|
||||
|
||||
SliceIndex __nac3_slice_index_bound(SliceIndex i, const SliceIndex len) {
|
||||
if (i < 0) {
|
||||
i = len + i;
|
||||
}
|
||||
if (i < 0) {
|
||||
return 0;
|
||||
} else if (i > len) {
|
||||
return len;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
SliceIndex __nac3_range_slice_len(
|
||||
const SliceIndex start,
|
||||
const SliceIndex end,
|
||||
const SliceIndex step
|
||||
) {
|
||||
SliceIndex diff = end - start;
|
||||
if (diff > 0 && step > 0) {
|
||||
return ((diff - 1) / step) + 1;
|
||||
} else if (diff < 0 && step < 0) {
|
||||
return ((diff + 1) / step) + 1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle list assignment and dropping part of the list when
|
||||
// both dest_step and src_step are +1.
|
||||
// - All the index must *not* be out-of-bound or negative,
|
||||
// - The end index is *inclusive*,
|
||||
// - The length of src and dest slice size should already
|
||||
// be checked: if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest)
|
||||
SliceIndex __nac3_list_slice_assign_var_size(
|
||||
SliceIndex dest_start,
|
||||
SliceIndex dest_end,
|
||||
SliceIndex dest_step,
|
||||
uint8_t* dest_arr,
|
||||
SliceIndex dest_arr_len,
|
||||
SliceIndex src_start,
|
||||
SliceIndex src_end,
|
||||
SliceIndex src_step,
|
||||
uint8_t* src_arr,
|
||||
SliceIndex src_arr_len,
|
||||
const SliceIndex size
|
||||
) {
|
||||
/* if dest_arr_len == 0, do nothing since we do not support extending list */
|
||||
if (dest_arr_len == 0) return dest_arr_len;
|
||||
/* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */
|
||||
if (src_step == dest_step && dest_step == 1) {
|
||||
const SliceIndex src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0;
|
||||
const SliceIndex dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0;
|
||||
if (src_len > 0) {
|
||||
__builtin_memmove(
|
||||
dest_arr + dest_start * size,
|
||||
src_arr + src_start * size,
|
||||
src_len * size
|
||||
);
|
||||
}
|
||||
if (dest_len > 0) {
|
||||
/* dropping */
|
||||
__builtin_memmove(
|
||||
dest_arr + (dest_start + src_len) * size,
|
||||
dest_arr + (dest_end + 1) * size,
|
||||
(dest_arr_len - dest_end - 1) * size
|
||||
);
|
||||
}
|
||||
/* shrink size */
|
||||
return dest_arr_len - (dest_len - src_len);
|
||||
}
|
||||
/* if two range overlaps, need alloca */
|
||||
uint8_t need_alloca =
|
||||
(dest_arr == src_arr)
|
||||
&& !(
|
||||
max(dest_start, dest_end) < min(src_start, src_end)
|
||||
|| max(src_start, src_end) < min(dest_start, dest_end)
|
||||
);
|
||||
if (need_alloca) {
|
||||
uint8_t* tmp = reinterpret_cast<uint8_t *>(__builtin_alloca(src_arr_len * size));
|
||||
__builtin_memcpy(tmp, src_arr, src_arr_len * size);
|
||||
src_arr = tmp;
|
||||
}
|
||||
SliceIndex src_ind = src_start;
|
||||
SliceIndex dest_ind = dest_start;
|
||||
for (;
|
||||
(src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end);
|
||||
src_ind += src_step, dest_ind += dest_step
|
||||
) {
|
||||
/* for constant optimization */
|
||||
if (size == 1) {
|
||||
__builtin_memcpy(dest_arr + dest_ind, src_arr + src_ind, 1);
|
||||
} else if (size == 4) {
|
||||
__builtin_memcpy(dest_arr + dest_ind * 4, src_arr + src_ind * 4, 4);
|
||||
} else if (size == 8) {
|
||||
__builtin_memcpy(dest_arr + dest_ind * 8, src_arr + src_ind * 8, 8);
|
||||
} else {
|
||||
/* memcpy for var size, cannot overlap after previous alloca */
|
||||
__builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size);
|
||||
}
|
||||
}
|
||||
/* only dest_step == 1 can we shrink the dest list. */
|
||||
/* size should be ensured prior to calling this function */
|
||||
if (dest_step == 1 && dest_end >= dest_start) {
|
||||
__builtin_memmove(
|
||||
dest_arr + dest_ind * size,
|
||||
dest_arr + (dest_end + 1) * size,
|
||||
(dest_arr_len - dest_end - 1) * size
|
||||
);
|
||||
return dest_arr_len - (dest_end - dest_ind) - 1;
|
||||
}
|
||||
return dest_arr_len;
|
||||
}
|
||||
|
||||
int32_t __nac3_isinf(double x) {
|
||||
return __builtin_isinf(x);
|
||||
}
|
||||
|
||||
int32_t __nac3_isnan(double x) {
|
||||
return __builtin_isnan(x);
|
||||
}
|
||||
|
||||
double tgamma(double arg);
|
||||
|
||||
double __nac3_gamma(double z) {
|
||||
// Handling for denormals
|
||||
// | x | Python gamma(x) | C tgamma(x) |
|
||||
// --- | ----------------- | --------------- | ----------- |
|
||||
// (1) | nan | nan | nan |
|
||||
// (2) | -inf | -inf | inf |
|
||||
// (3) | inf | inf | inf |
|
||||
// (4) | 0.0 | inf | inf |
|
||||
// (5) | {-1.0, -2.0, ...} | inf | nan |
|
||||
|
||||
// (1)-(3)
|
||||
if (__builtin_isinf(z) || __builtin_isnan(z)) {
|
||||
return z;
|
||||
}
|
||||
|
||||
double v = tgamma(z);
|
||||
|
||||
// (4)-(5)
|
||||
return __builtin_isinf(v) || __builtin_isnan(v) ? __builtin_inf() : v;
|
||||
}
|
||||
|
||||
double lgamma(double arg);
|
||||
|
||||
double __nac3_gammaln(double x) {
|
||||
// libm's handling of value overflows differs from scipy:
|
||||
// - scipy: gammaln(-inf) -> -inf
|
||||
// - libm : lgamma(-inf) -> inf
|
||||
|
||||
if (__builtin_isinf(x)) {
|
||||
return x;
|
||||
}
|
||||
|
||||
return lgamma(x);
|
||||
}
|
||||
|
||||
double j0(double x);
|
||||
|
||||
double __nac3_j0(double x) {
|
||||
// libm's handling of value overflows differs from scipy:
|
||||
// - scipy: j0(inf) -> nan
|
||||
// - libm : j0(inf) -> 0.0
|
||||
|
||||
if (__builtin_isinf(x)) {
|
||||
return __builtin_nan("");
|
||||
}
|
||||
|
||||
return j0(x);
|
||||
}
|
||||
|
||||
uint32_t __nac3_ndarray_calc_size(
|
||||
const uint32_t* list_data,
|
||||
uint32_t list_len,
|
||||
uint32_t begin_idx,
|
||||
uint32_t end_idx
|
||||
) {
|
||||
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||
}
|
||||
|
||||
uint64_t __nac3_ndarray_calc_size64(
|
||||
const uint64_t* list_data,
|
||||
uint64_t list_len,
|
||||
uint64_t begin_idx,
|
||||
uint64_t end_idx
|
||||
) {
|
||||
return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices(
|
||||
uint32_t index,
|
||||
const uint32_t* dims,
|
||||
uint32_t num_dims,
|
||||
NDIndex* idxs
|
||||
) {
|
||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices64(
|
||||
uint64_t index,
|
||||
const uint64_t* dims,
|
||||
uint64_t num_dims,
|
||||
NDIndex* idxs
|
||||
) {
|
||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||
}
|
||||
|
||||
uint32_t __nac3_ndarray_flatten_index(
|
||||
const uint32_t* dims,
|
||||
uint32_t num_dims,
|
||||
const NDIndex* indices,
|
||||
uint32_t num_indices
|
||||
) {
|
||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||
}
|
||||
|
||||
uint64_t __nac3_ndarray_flatten_index64(
|
||||
const uint64_t* dims,
|
||||
uint64_t num_dims,
|
||||
const NDIndex* indices,
|
||||
uint64_t num_indices
|
||||
) {
|
||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast(
|
||||
const uint32_t* lhs_dims,
|
||||
uint32_t lhs_ndims,
|
||||
const uint32_t* rhs_dims,
|
||||
uint32_t rhs_ndims,
|
||||
uint32_t* out_dims
|
||||
) {
|
||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast64(
|
||||
const uint64_t* lhs_dims,
|
||||
uint64_t lhs_ndims,
|
||||
const uint64_t* rhs_dims,
|
||||
uint64_t rhs_ndims,
|
||||
uint64_t* out_dims
|
||||
) {
|
||||
return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast_idx(
|
||||
const uint32_t* src_dims,
|
||||
uint32_t src_ndims,
|
||||
const NDIndex* in_idx,
|
||||
NDIndex* out_idx
|
||||
) {
|
||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast_idx64(
|
||||
const uint64_t* src_dims,
|
||||
uint64_t src_ndims,
|
||||
const NDIndex* in_idx,
|
||||
NDIndex* out_idx
|
||||
) {
|
||||
__nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx);
|
||||
}
|
||||
} // extern "C"
|
@ -1,29 +1,28 @@
|
||||
use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type};
|
||||
use crate::typecheck::typedef::Type;
|
||||
|
||||
use super::{
|
||||
classes::{
|
||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
|
||||
TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
|
||||
TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
|
||||
},
|
||||
llvm_intrinsics,
|
||||
macros::codegen_unreachable,
|
||||
stmt::gen_for_callback_incrementing,
|
||||
CodeGenContext, CodeGenerator,
|
||||
llvm_intrinsics, CodeGenContext, CodeGenerator,
|
||||
};
|
||||
use crate::codegen::classes::TypedArrayLikeAccessor;
|
||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
context::Context,
|
||||
memory_buffer::MemoryBuffer,
|
||||
module::Module,
|
||||
types::{BasicTypeEnum, IntType},
|
||||
values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
||||
values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue},
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
use itertools::Either;
|
||||
use nac3parser::ast::Expr;
|
||||
|
||||
#[must_use]
|
||||
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
||||
pub fn load_irrt(ctx: &Context) -> Module {
|
||||
let bitcode_buf = MemoryBuffer::create_from_memory_range(
|
||||
include_bytes!(concat!(env!("OUT_DIR"), "/irrt.bc")),
|
||||
"irrt_bitcode_buffer",
|
||||
@ -39,25 +38,6 @@ pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver)
|
||||
let function = irrt_mod.get_function(symbol).unwrap();
|
||||
function.add_attribute(AttributeLoc::Function, ctx.create_enum_attribute(inline_attr, 0));
|
||||
}
|
||||
|
||||
// Initialize all global `EXN_*` exception IDs in IRRT with the [`SymbolResolver`].
|
||||
let exn_id_type = ctx.i32_type();
|
||||
let errors = &[
|
||||
("EXN_INDEX_ERROR", "0:IndexError"),
|
||||
("EXN_VALUE_ERROR", "0:ValueError"),
|
||||
("EXN_ASSERTION_ERROR", "0:AssertionError"),
|
||||
("EXN_TYPE_ERROR", "0:TypeError"),
|
||||
];
|
||||
for (irrt_name, symbol_name) in errors {
|
||||
let exn_id = symbol_resolver.get_string_id(symbol_name);
|
||||
let exn_id = exn_id_type.const_int(exn_id as u64, false).as_basic_value_enum();
|
||||
|
||||
let global = irrt_mod.get_global(irrt_name).unwrap_or_else(|| {
|
||||
panic!("Exception symbol name '{irrt_name}' should exist in the IRRT LLVM module")
|
||||
});
|
||||
global.set_initializer(&exn_id);
|
||||
}
|
||||
|
||||
irrt_mod
|
||||
}
|
||||
|
||||
@ -75,7 +55,7 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||
(64, 64, true) => "__nac3_int_exp_int64_t",
|
||||
(32, 32, false) => "__nac3_int_exp_uint32_t",
|
||||
(64, 64, false) => "__nac3_int_exp_uint64_t",
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let base_type = base.get_type();
|
||||
let pow_fun = ctx.module.get_function(symbol).unwrap_or_else(|| {
|
||||
@ -461,7 +441,7 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
|
||||
BasicTypeEnum::IntType(t) => t.size_of(),
|
||||
BasicTypeEnum::PointerType(t) => t.size_of(),
|
||||
BasicTypeEnum::StructType(t) => t.size_of().unwrap(),
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size").unwrap()
|
||||
}
|
||||
@ -588,8 +568,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo
|
||||
///
|
||||
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
|
||||
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
|
||||
/// or [`None`] if starting from the first dimension and ending at the last dimension
|
||||
/// respectively.
|
||||
/// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
|
||||
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
|
||||
generator: &G,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
@ -606,7 +585,7 @@ where
|
||||
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_size",
|
||||
64 => "__nac3_ndarray_calc_size64",
|
||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
|
||||
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
|
||||
@ -641,7 +620,7 @@ where
|
||||
///
|
||||
/// * `index` - The index to compute the multidimensional index for.
|
||||
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
/// `NDArray`.
|
||||
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -657,7 +636,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_nd_indices",
|
||||
64 => "__nac3_ndarray_calc_nd_indices64",
|
||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_nd_indices_fn =
|
||||
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
|
||||
@ -726,7 +705,7 @@ where
|
||||
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_flatten_index",
|
||||
64 => "__nac3_ndarray_flatten_index64",
|
||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_flatten_index_fn =
|
||||
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
|
||||
@ -765,7 +744,7 @@ where
|
||||
/// multidimensional index.
|
||||
///
|
||||
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
/// `NDArray`.
|
||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
|
||||
generator: &mut G,
|
||||
@ -794,7 +773,7 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_broadcast",
|
||||
64 => "__nac3_ndarray_calc_broadcast64",
|
||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_broadcast_fn =
|
||||
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||
@ -914,7 +893,7 @@ pub fn call_ndarray_calc_broadcast_index<
|
||||
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_broadcast_idx",
|
||||
64 => "__nac3_ndarray_calc_broadcast_idx64",
|
||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw),
|
||||
};
|
||||
let ndarray_calc_broadcast_fn =
|
||||
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||
|
@ -205,9 +205,8 @@ pub fn call_memcpy_generic<'ctx>(
|
||||
/// * `$ctx:ident`: Reference to the current Code Generation Context
|
||||
/// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>)
|
||||
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
|
||||
/// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type).
|
||||
/// Use `BasicValueEnum::into_int_value` for Integer return type and
|
||||
/// `BasicValueEnum::into_float_value` for Float return type
|
||||
/// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type)
|
||||
/// Use `BasicValueEnum::into_int_value` for Integer return type and `BasicValueEnum::into_float_value` for Float return type
|
||||
/// * `$llvm_ty:ident`: Type of first operand
|
||||
/// * `,($val:ident)*`: Comma separated list of operands
|
||||
macro_rules! generate_llvm_intrinsic_fn_body {
|
||||
@ -223,8 +222,8 @@ macro_rules! generate_llvm_intrinsic_fn_body {
|
||||
/// Arguments:
|
||||
/// * `float/int`: Indicates the return and argument type of the function
|
||||
/// * `$fn_name:ident`: The identifier of the rust function to be generated
|
||||
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function.
|
||||
/// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil"
|
||||
/// * `$llvm_name:literal`: Name of underlying llvm intrinsic function
|
||||
/// Omit "llvm." prefix from the function name i.e. use "ceil" instead of "llvm.ceil"
|
||||
/// * `$val:ident`: The operand for unary operations
|
||||
/// * `$val1:ident`, `$val2:ident`: The operands for binary operations
|
||||
macro_rules! generate_llvm_intrinsic_fn {
|
||||
|
@ -50,22 +50,6 @@ mod test;
|
||||
use concrete_type::{ConcreteType, ConcreteTypeEnum, ConcreteTypeStore};
|
||||
pub use generator::{CodeGenerator, DefaultCodeGenerator};
|
||||
|
||||
mod macros {
|
||||
/// Codegen-variant of [`std::unreachable`] which accepts an instance of [`CodeGenContext`] as
|
||||
/// its first argument to provide Python source information to indicate the codegen location
|
||||
/// causing the assertion.
|
||||
macro_rules! codegen_unreachable {
|
||||
($ctx:expr $(,)?) => {
|
||||
std::unreachable!("unreachable code while processing {}", &$ctx.current_loc)
|
||||
};
|
||||
($ctx:expr, $($arg:tt)*) => {
|
||||
std::unreachable!("unreachable code while processing {}: {}", &$ctx.current_loc, std::format!("{}", std::format_args!($($arg)+)))
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) use codegen_unreachable;
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct StaticValueStore {
|
||||
pub lookup: HashMap<Vec<(usize, u64)>, usize>,
|
||||
@ -596,11 +580,11 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
|
||||
) -> BasicTypeEnum<'ctx> {
|
||||
// If the type is used in the definition of a function, return `i1` instead of `i8` for ABI
|
||||
// consistency.
|
||||
if unifier.unioned(ty, primitives.bool) {
|
||||
return if unifier.unioned(ty, primitives.bool) {
|
||||
ctx.bool_type().into()
|
||||
} else {
|
||||
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Whether `sret` is needed for a return value with type `ty`.
|
||||
|
@ -12,7 +12,6 @@ use crate::{
|
||||
call_ndarray_calc_size,
|
||||
},
|
||||
llvm_intrinsics::{self, call_memcpy_generic},
|
||||
macros::codegen_unreachable,
|
||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
@ -260,7 +259,7 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "").into()
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
@ -288,7 +287,7 @@ fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "1").into()
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
@ -356,7 +355,7 @@ fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
||||
}
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -627,7 +626,7 @@ fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
} else if fill_value.is_int_value() || fill_value.is_float_value() {
|
||||
fill_value
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
Ok(value)
|
||||
@ -1072,15 +1071,15 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// Copies a slice of an [`NDArrayValue`] to another.
|
||||
///
|
||||
/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `dim_sz`
|
||||
/// fields should be populated before calling this function.
|
||||
/// fields should be populated before calling this function.
|
||||
/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
|
||||
/// dimensional slice in the destination array.
|
||||
/// dimensional slice in the destination array.
|
||||
/// - `src_arr`: The [`NDArrayValue`] instance of the source array.
|
||||
/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing
|
||||
/// dimensional slice in the source array.
|
||||
/// dimensional slice in the source array.
|
||||
/// - `dim`: The index of the currently processing dimension.
|
||||
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
|
||||
/// this dimension. The `start`/`stop` values of each slice must be non-negative indices.
|
||||
/// this dimension. The `start`/`stop` values of each slice must be non-negative indices.
|
||||
fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -1185,7 +1184,7 @@ fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to
|
||||
/// this dimension. The `start`/`stop` values of each slice must be positive indices.
|
||||
/// this dimension. The `start`/`stop` values of each slice must be positive indices.
|
||||
pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -1350,7 +1349,7 @@ where
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
||||
/// written to a new `ndarray`.
|
||||
/// written to a new `ndarray`.
|
||||
/// * `value_fn` - Function mapping the two input elements into the result.
|
||||
///
|
||||
/// # Panic
|
||||
@ -1437,7 +1436,7 @@ where
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
||||
/// written to a new `ndarray`.
|
||||
/// written to a new `ndarray`.
|
||||
pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -2021,7 +2020,7 @@ pub fn gen_ndarray_fill<'ctx>(
|
||||
} else if value_arg.is_int_value() || value_arg.is_float_value() {
|
||||
value_arg
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
Ok(value)
|
||||
@ -2130,8 +2129,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(out.as_base_value().into())
|
||||
} else {
|
||||
codegen_unreachable!(
|
||||
ctx,
|
||||
unreachable!(
|
||||
"{FN_NAME}() not supported for '{}'",
|
||||
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||
)
|
||||
@ -2142,12 +2140,11 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||
///
|
||||
/// * `x1` - `NDArray` to reshape.
|
||||
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
|
||||
/// Just like numpy, the `shape` argument can be:
|
||||
/// Just like numpy, the `shape` argument can be:
|
||||
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
|
||||
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
||||
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
||||
///
|
||||
/// Note that unlike other generating functions, one of the dimensions in the shape can be negative.
|
||||
/// Note that unlike other generating functions, one of the dimesions in the shape can be negative
|
||||
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
@ -2373,7 +2370,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
.into_int_value();
|
||||
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
||||
}
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
.unwrap();
|
||||
|
||||
@ -2417,8 +2414,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
||||
Ok(out.as_base_value().into())
|
||||
} else {
|
||||
codegen_unreachable!(
|
||||
ctx,
|
||||
unreachable!(
|
||||
"{FN_NAME}() not supported for '{}'",
|
||||
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||
)
|
||||
@ -2486,7 +2482,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
.build_float_mul(e1, elem2.into_float_value(), "")
|
||||
.unwrap()
|
||||
.as_basic_value_enum(),
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
||||
let acc_val = match acc_val {
|
||||
@ -2500,7 +2496,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
.build_float_add(e1, product.into_float_value(), "")
|
||||
.unwrap()
|
||||
.as_basic_value_enum(),
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
ctx.builder.build_store(acc, acc_val).unwrap();
|
||||
|
||||
@ -2517,8 +2513,7 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
|
||||
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
||||
}
|
||||
_ => codegen_unreachable!(
|
||||
ctx,
|
||||
_ => unreachable!(
|
||||
"{FN_NAME}() not supported for '{}'",
|
||||
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||
),
|
||||
|
@ -1,13 +1,15 @@
|
||||
use super::{
|
||||
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
|
||||
expr::{destructure_range, gen_binop_expr},
|
||||
gen_in_range_check,
|
||||
super::symbol_resolver::ValueEnum,
|
||||
expr::destructure_range,
|
||||
irrt::{handle_slice_indices, list_slice_assignment},
|
||||
macros::codegen_unreachable,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
use crate::{
|
||||
symbol_resolver::ValueEnum,
|
||||
codegen::{
|
||||
classes::{ArrayLikeIndexer, ArraySliceValue, ListValue, RangeValue},
|
||||
expr::gen_binop_expr,
|
||||
gen_in_range_check,
|
||||
},
|
||||
toplevel::{DefinitionId, TopLevelDef},
|
||||
typecheck::{
|
||||
magic_methods::Binop,
|
||||
@ -119,7 +121,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||
return Ok(None);
|
||||
};
|
||||
let BasicValueEnum::PointerValue(ptr) = val else {
|
||||
codegen_unreachable!(ctx);
|
||||
unreachable!();
|
||||
};
|
||||
unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
@ -133,7 +135,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
.unwrap()
|
||||
}
|
||||
_ => codegen_unreachable!(ctx),
|
||||
_ => unreachable!(),
|
||||
}))
|
||||
}
|
||||
|
||||
@ -174,14 +176,6 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
}
|
||||
let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?;
|
||||
|
||||
// Perform i1 <-> i8 conversion as needed
|
||||
let val = if ctx.unifier.unioned(target.custom.unwrap(), ctx.primitives.bool) {
|
||||
generator.bool_to_i8(ctx, val.into_int_value()).into()
|
||||
} else {
|
||||
val
|
||||
};
|
||||
|
||||
ctx.builder.build_store(ptr, val).unwrap();
|
||||
}
|
||||
};
|
||||
@ -199,12 +193,12 @@ pub fn gen_assign_target_list<'ctx, G: CodeGenerator>(
|
||||
// Deconstruct the tuple `value`
|
||||
let BasicValueEnum::StructValue(tuple) = value.to_basic_value_enum(ctx, generator, value_ty)?
|
||||
else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
// NOTE: Currently, RHS's type is forced to be a Tuple by the type inferencer.
|
||||
let TypeEnum::TTuple { ty: tuple_tys, .. } = &*ctx.unifier.get_ty(value_ty) else {
|
||||
codegen_unreachable!(ctx);
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
assert_eq!(tuple.get_type().count_fields() as usize, tuple_tys.len());
|
||||
@ -264,7 +258,7 @@ pub fn gen_assign_target_list<'ctx, G: CodeGenerator>(
|
||||
// Now assign with that sub-tuple to the starred target.
|
||||
generator.gen_assign(ctx, target, ValueEnum::Dynamic(sub_tuple_val), sub_tuple_ty)?;
|
||||
} else {
|
||||
codegen_unreachable!(ctx) // The typechecker ensures this
|
||||
unreachable!() // The typechecker ensures this
|
||||
}
|
||||
|
||||
// Handle assignment after the starred target
|
||||
@ -312,9 +306,7 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||
|
||||
if let ExprKind::Slice { .. } = &key.node {
|
||||
// Handle assigning to a slice
|
||||
let ExprKind::Slice { lower, upper, step } = &key.node else {
|
||||
codegen_unreachable!(ctx)
|
||||
};
|
||||
let ExprKind::Slice { lower, upper, step } = &key.node else { unreachable!() };
|
||||
let Some((start, end, step)) = handle_slice_indices(
|
||||
lower,
|
||||
upper,
|
||||
@ -424,9 +416,7 @@ pub fn gen_for<G: CodeGenerator>(
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
stmt: &Stmt<Option<Type>>,
|
||||
) -> Result<(), String> {
|
||||
let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else {
|
||||
codegen_unreachable!(ctx)
|
||||
};
|
||||
let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else { unreachable!() };
|
||||
|
||||
// var_assignment static values may be changed in another branch
|
||||
// if so, remove the static value as it may not be correct in this branch
|
||||
@ -468,7 +458,7 @@ pub fn gen_for<G: CodeGenerator>(
|
||||
let Some(target_i) =
|
||||
generator.gen_store_target(ctx, target, Some("for.target.addr"))?
|
||||
else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
let (start, stop, step) = destructure_range(ctx, iter_val);
|
||||
|
||||
@ -629,9 +619,9 @@ pub struct BreakContinueHooks<'ctx> {
|
||||
/// ```
|
||||
///
|
||||
/// * `init` - A lambda containing IR statements declaring and initializing loop variables. The
|
||||
/// return value is a [Clone] value which will be passed to the other lambdas.
|
||||
/// return value is a [Clone] value which will be passed to the other lambdas.
|
||||
/// * `cond` - A lambda containing IR statements checking whether the loop should continue
|
||||
/// executing. The result value must be an `i1` indicating if the loop should continue.
|
||||
/// executing. The result value must be an `i1` indicating if the loop should continue.
|
||||
/// * `body` - A lambda containing IR statements within the loop body.
|
||||
/// * `update` - A lambda containing IR statements updating loop variables.
|
||||
pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
|
||||
@ -648,12 +638,8 @@ where
|
||||
I: Clone,
|
||||
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
|
||||
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
|
||||
BodyFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
I,
|
||||
) -> Result<(), String>,
|
||||
BodyFn:
|
||||
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, BreakContinueHooks, I) -> Result<(), String>,
|
||||
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||
{
|
||||
let label = label.unwrap_or("for");
|
||||
@ -714,9 +700,9 @@ where
|
||||
/// ```
|
||||
///
|
||||
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
|
||||
/// as the type of the loop variable.
|
||||
/// as the type of the loop variable.
|
||||
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum
|
||||
/// value should be treated as inclusive (as opposed to exclusive).
|
||||
/// value should be treated as inclusive (as opposed to exclusive).
|
||||
/// * `body` - A lambda containing IR statements within the loop body.
|
||||
/// * `incr_val` - The value to increment the loop variable on each iteration.
|
||||
pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
|
||||
@ -733,7 +719,7 @@ where
|
||||
BodyFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
BreakContinueHooks,
|
||||
IntValue<'ctx>,
|
||||
) -> Result<(), String>,
|
||||
{
|
||||
@ -787,12 +773,12 @@ where
|
||||
///
|
||||
/// - `is_unsigned`: Whether to treat the values of the `range` as unsigned.
|
||||
/// - `start_fn`: A lambda of IR statements that retrieves the `start` value of the `range`-like
|
||||
/// iterable.
|
||||
/// iterable.
|
||||
/// - `stop_fn`: A lambda of IR statements that retrieves the `stop` value of the `range`-like
|
||||
/// iterable. This value will be extended to the size of `start`.
|
||||
/// iterable. This value will be extended to the size of `start`.
|
||||
/// - `stop_inclusive`: Whether the stop value should be treated as inclusive.
|
||||
/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like
|
||||
/// iterable. This value will be extended to the size of `start`.
|
||||
/// iterable. This value will be extended to the size of `start`.
|
||||
/// - `body_fn`: A lambda of IR statements within the loop body.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
|
||||
@ -813,7 +799,7 @@ where
|
||||
BodyFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
BreakContinueHooks<'ctx>,
|
||||
BreakContinueHooks,
|
||||
IntValue<'ctx>,
|
||||
) -> Result<(), String>,
|
||||
{
|
||||
@ -911,7 +897,7 @@ pub fn gen_while<G: CodeGenerator>(
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
stmt: &Stmt<Option<Type>>,
|
||||
) -> Result<(), String> {
|
||||
let StmtKind::While { test, body, orelse, .. } = &stmt.node else { codegen_unreachable!(ctx) };
|
||||
let StmtKind::While { test, body, orelse, .. } = &stmt.node else { unreachable!() };
|
||||
|
||||
// var_assignment static values may be changed in another branch
|
||||
// if so, remove the static value as it may not be correct in this branch
|
||||
@ -941,7 +927,7 @@ pub fn gen_while<G: CodeGenerator>(
|
||||
|
||||
return Ok(());
|
||||
};
|
||||
let BasicValueEnum::IntValue(test) = test else { codegen_unreachable!(ctx) };
|
||||
let BasicValueEnum::IntValue(test) = test else { unreachable!() };
|
||||
|
||||
ctx.builder
|
||||
.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb)
|
||||
@ -1089,7 +1075,7 @@ pub fn gen_if<G: CodeGenerator>(
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
stmt: &Stmt<Option<Type>>,
|
||||
) -> Result<(), String> {
|
||||
let StmtKind::If { test, body, orelse, .. } = &stmt.node else { codegen_unreachable!(ctx) };
|
||||
let StmtKind::If { test, body, orelse, .. } = &stmt.node else { unreachable!() };
|
||||
|
||||
// var_assignment static values may be changed in another branch
|
||||
// if so, remove the static value as it may not be correct in this branch
|
||||
@ -1212,11 +1198,11 @@ pub fn exn_constructor<'ctx>(
|
||||
let zelf_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(zelf_ty) {
|
||||
obj_id.0
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
let defs = ctx.top_level.definitions.read();
|
||||
let def = defs[zelf_id].read();
|
||||
let TopLevelDef::Class { name: zelf_name, .. } = &*def else { codegen_unreachable!(ctx) };
|
||||
let TopLevelDef::Class { name: zelf_name, .. } = &*def else { unreachable!() };
|
||||
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name);
|
||||
unsafe {
|
||||
let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap();
|
||||
@ -1324,7 +1310,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
|
||||
target: &Stmt<Option<Type>>,
|
||||
) -> Result<(), String> {
|
||||
let StmtKind::Try { body, handlers, orelse, finalbody, .. } = &target.node else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
// if we need to generate anything related to exception, we must have personality defined
|
||||
@ -1401,7 +1387,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
|
||||
if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) {
|
||||
*obj_id
|
||||
} else {
|
||||
codegen_unreachable!(ctx)
|
||||
unreachable!()
|
||||
};
|
||||
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name);
|
||||
let exn_id = ctx.resolver.get_string_id(&exception_name);
|
||||
@ -1673,23 +1659,6 @@ pub fn gen_return<G: CodeGenerator>(
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Remap boolean return type into i1
|
||||
let value = value.map(|ret_val| {
|
||||
// The "return type" of a sret function is in the first parameter
|
||||
let expected_ty = if ctx.need_sret {
|
||||
func.get_type().get_param_types()[0]
|
||||
} else {
|
||||
func.get_type().get_return_type().unwrap()
|
||||
};
|
||||
|
||||
if matches!(expected_ty, BasicTypeEnum::IntType(ty) if ty.get_bit_width() == 1) {
|
||||
generator.bool_to_i1(ctx, ret_val.into_int_value()).into()
|
||||
} else {
|
||||
ret_val
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(return_target) = ctx.return_target {
|
||||
if let Some(value) = value {
|
||||
ctx.builder.build_store(ctx.return_buffer.unwrap(), value).unwrap();
|
||||
@ -1700,6 +1669,25 @@ pub fn gen_return<G: CodeGenerator>(
|
||||
ctx.builder.build_store(ctx.return_buffer.unwrap(), value.unwrap()).unwrap();
|
||||
ctx.builder.build_return(None).unwrap();
|
||||
} else {
|
||||
// Remap boolean return type into i1
|
||||
let value = value.map(|v| {
|
||||
let expected_ty = func.get_type().get_return_type().unwrap();
|
||||
let ret_val = v.as_basic_value_enum();
|
||||
|
||||
if expected_ty.is_int_type() && ret_val.is_int_value() {
|
||||
let ret_type = expected_ty.into_int_type();
|
||||
let ret_val = ret_val.into_int_value();
|
||||
|
||||
if ret_type.get_bit_width() == 1 && ret_val.get_type().get_bit_width() != 1 {
|
||||
generator.bool_to_i1(ctx, ret_val)
|
||||
} else {
|
||||
ret_val
|
||||
}
|
||||
.into()
|
||||
} else {
|
||||
ret_val
|
||||
}
|
||||
});
|
||||
let value = value.as_ref().map(|v| v as &dyn BasicValue);
|
||||
ctx.builder.build_return(value).unwrap();
|
||||
}
|
||||
@ -1768,30 +1756,7 @@ pub fn gen_stmt<G: CodeGenerator>(
|
||||
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
|
||||
StmtKind::Raise { exc, .. } => {
|
||||
if let Some(exc) = exc {
|
||||
let exn = if let ExprKind::Name { id, .. } = &exc.node {
|
||||
// Handle "raise Exception" short form
|
||||
let def_id = ctx.resolver.get_identifier_def(*id).map_err(|e| {
|
||||
format!("{} (at {})", e.iter().next().unwrap(), exc.location)
|
||||
})?;
|
||||
let def = ctx.top_level.definitions.read();
|
||||
let TopLevelDef::Class { constructor, .. } = *def[def_id.0].read() else {
|
||||
return Err(format!("Failed to resolve symbol {id} (at {})", exc.location));
|
||||
};
|
||||
|
||||
let TypeEnum::TFunc(signature) =
|
||||
ctx.unifier.get_ty(constructor.unwrap()).as_ref().clone()
|
||||
else {
|
||||
return Err(format!("Failed to resolve symbol {id} (at {})", exc.location));
|
||||
};
|
||||
|
||||
generator
|
||||
.gen_call(ctx, None, (&signature, def_id), Vec::default())?
|
||||
.map(Into::into)
|
||||
} else {
|
||||
generator.gen_expr(ctx, exc)?
|
||||
};
|
||||
|
||||
let exc = if let Some(v) = exn {
|
||||
let exc = if let Some(v) = generator.gen_expr(ctx, exc)? {
|
||||
v.to_basic_value_enum(ctx, generator, exc.custom.unwrap())?
|
||||
} else {
|
||||
return Ok(());
|
||||
|
@ -112,7 +112,7 @@ pub fn get_exn_constructor(
|
||||
/// * `name`: The name of the implemented NumPy function.
|
||||
/// * `ret_ty`: The return type of this function.
|
||||
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
|
||||
/// [parameter type][Type] and the parameter symbol name.
|
||||
/// [parameter type][Type] and the parameter symbol name.
|
||||
/// * `codegen_callback`: A lambda generating LLVM IR for the implementation of this function.
|
||||
fn create_fn_by_codegen(
|
||||
unifier: &mut Unifier,
|
||||
@ -152,7 +152,7 @@ fn create_fn_by_codegen(
|
||||
/// * `name`: The name of the implemented NumPy function.
|
||||
/// * `ret_ty`: The return type of this function.
|
||||
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
|
||||
/// [parameter type][Type] and the parameter symbol name.
|
||||
/// [parameter type][Type] and the parameter symbol name.
|
||||
/// * `intrinsic_fn`: The fully-qualified name of the LLVM intrinsic function.
|
||||
fn create_fn_by_intrinsic(
|
||||
unifier: &mut Unifier,
|
||||
@ -214,10 +214,10 @@ fn create_fn_by_intrinsic(
|
||||
/// * `name`: The name of the implemented NumPy function.
|
||||
/// * `ret_ty`: The return type of this function.
|
||||
/// * `param_ty`: The parameters accepted by this function, represented by a tuple of the
|
||||
/// [parameter type][Type] and the parameter symbol name.
|
||||
/// [parameter type][Type] and the parameter symbol name.
|
||||
/// * `extern_fn`: The fully-qualified name of the extern function used as the implementation.
|
||||
/// * `attrs`: The list of attributes to apply to this function declaration. Note that `nounwind` is
|
||||
/// already implied by the C ABI.
|
||||
/// already implied by the C ABI.
|
||||
fn create_fn_by_extern(
|
||||
unifier: &mut Unifier,
|
||||
var_map: &VarMap,
|
||||
|
@ -1822,9 +1822,8 @@ impl TopLevelComposer {
|
||||
if *name != init_str_id {
|
||||
unreachable!("must be init function here")
|
||||
}
|
||||
|
||||
let all_inited = Self::get_all_assigned_field(
|
||||
object_id.0,
|
||||
class_name.to_string().into(),
|
||||
definition_ast_list,
|
||||
body.as_slice(),
|
||||
)?;
|
||||
|
@ -734,14 +734,10 @@ impl TopLevelComposer {
|
||||
)
|
||||
}
|
||||
|
||||
/// This function returns the fields that have been initialized in the `__init__` function of a class
|
||||
/// The function takes as input:
|
||||
/// * `class_id`: The `object_id` of the class whose function is being evaluated (check `TopLevelDef::Class`)
|
||||
/// * `definition_ast_list`: A list of ast definitions and statements defined in `TopLevelComposer`
|
||||
/// * `stmts`: The body of function being parsed. Each statment is analyzed to check varaible initialization statements
|
||||
#[allow(clippy::only_used_in_recursion)]
|
||||
pub fn get_all_assigned_field(
|
||||
class_id: usize,
|
||||
definition_ast_list: &Vec<DefAst>,
|
||||
class_name: StrRef,
|
||||
ast: &Vec<DefAst>,
|
||||
stmts: &[Stmt<()>],
|
||||
) -> Result<HashSet<StrRef>, HashSet<String>> {
|
||||
let mut result = HashSet::new();
|
||||
@ -779,60 +775,46 @@ impl TopLevelComposer {
|
||||
// TODO: do not check for For and While?
|
||||
ast::StmtKind::For { body, orelse, .. }
|
||||
| ast::StmtKind::While { body, orelse, .. } => {
|
||||
result.extend(Self::get_all_assigned_field(class_name, ast, body.as_slice())?);
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
class_id,
|
||||
definition_ast_list,
|
||||
body.as_slice(),
|
||||
)?);
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
class_id,
|
||||
definition_ast_list,
|
||||
class_name,
|
||||
ast,
|
||||
orelse.as_slice(),
|
||||
)?);
|
||||
}
|
||||
ast::StmtKind::If { body, orelse, .. } => {
|
||||
let inited_for_sure = Self::get_all_assigned_field(
|
||||
class_id,
|
||||
definition_ast_list,
|
||||
body.as_slice(),
|
||||
)?
|
||||
.intersection(&Self::get_all_assigned_field(
|
||||
class_id,
|
||||
definition_ast_list,
|
||||
orelse.as_slice(),
|
||||
)?)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
let inited_for_sure =
|
||||
Self::get_all_assigned_field(class_name, ast, body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(
|
||||
class_name,
|
||||
ast,
|
||||
orelse.as_slice(),
|
||||
)?)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
result.extend(inited_for_sure);
|
||||
}
|
||||
ast::StmtKind::Try { body, orelse, finalbody, .. } => {
|
||||
let inited_for_sure = Self::get_all_assigned_field(
|
||||
class_id,
|
||||
definition_ast_list,
|
||||
body.as_slice(),
|
||||
)?
|
||||
.intersection(&Self::get_all_assigned_field(
|
||||
class_id,
|
||||
definition_ast_list,
|
||||
orelse.as_slice(),
|
||||
)?)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
let inited_for_sure =
|
||||
Self::get_all_assigned_field(class_name, ast, body.as_slice())?
|
||||
.intersection(&Self::get_all_assigned_field(
|
||||
class_name,
|
||||
ast,
|
||||
orelse.as_slice(),
|
||||
)?)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
result.extend(inited_for_sure);
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
class_id,
|
||||
definition_ast_list,
|
||||
class_name,
|
||||
ast,
|
||||
finalbody.as_slice(),
|
||||
)?);
|
||||
}
|
||||
ast::StmtKind::With { body, .. } => {
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
class_id,
|
||||
definition_ast_list,
|
||||
body.as_slice(),
|
||||
)?);
|
||||
result.extend(Self::get_all_assigned_field(class_name, ast, body.as_slice())?);
|
||||
}
|
||||
// Variables Initialized in function calls
|
||||
// Variables Initiated in function calls
|
||||
ast::StmtKind::Expr { value, .. } => {
|
||||
let ExprKind::Call { func, .. } = &value.node else {
|
||||
continue;
|
||||
@ -843,70 +825,23 @@ impl TopLevelComposer {
|
||||
let ExprKind::Name { id, .. } = &value.node else {
|
||||
continue;
|
||||
};
|
||||
// Need to consider the two cases:
|
||||
// Need to conside the two cases:
|
||||
// Case 1) Call to class function i.e. id = `self`
|
||||
// Case 2) Call to class ancestor function i.e. id = ancestor_name
|
||||
// We leave checking whether function in case 2 belonged to class ancestor or not to type checker
|
||||
//
|
||||
// According to current handling of `self`, function definition are fixed and do not change regardless
|
||||
// of which object is passed as `self` i.e. virtual polymorphism is not supported
|
||||
// Therefore, we change class id for case 2 to reflect behavior of our compiler
|
||||
|
||||
let class_name = if *id == "self".into() {
|
||||
let ast::StmtKind::ClassDef { name, .. } =
|
||||
&definition_ast_list[class_id].1.as_ref().unwrap().node
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
name
|
||||
} else {
|
||||
id
|
||||
};
|
||||
|
||||
let parent_method = definition_ast_list.iter().find_map(|def| {
|
||||
let (
|
||||
class_def,
|
||||
Some(ast::Located {
|
||||
node: ast::StmtKind::ClassDef { name, body, .. },
|
||||
..
|
||||
}),
|
||||
) = &def
|
||||
else {
|
||||
return None;
|
||||
};
|
||||
let TopLevelDef::Class { object_id: class_id, .. } = &*class_def.read()
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
if name == class_name {
|
||||
body.iter().find_map(|m| {
|
||||
let ast::StmtKind::FunctionDef { name, body, .. } = &m.node else {
|
||||
return None;
|
||||
};
|
||||
if *name == *attr {
|
||||
return Some((body.clone(), class_id.0));
|
||||
}
|
||||
None
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
// If method body is none then method does not exist
|
||||
if let Some((method_body, class_id)) = parent_method {
|
||||
result.extend(Self::get_all_assigned_field(
|
||||
class_id,
|
||||
definition_ast_list,
|
||||
method_body.as_slice(),
|
||||
)?);
|
||||
} else {
|
||||
return Err(HashSet::from([format!(
|
||||
"{}.{} not found in class {class_name} at {}",
|
||||
*id, *attr, value.location
|
||||
)]));
|
||||
}
|
||||
// We leave checking whether ancestor is called to type checker
|
||||
// if *id == "self".into() {
|
||||
// ast.iter().find_map(|def| {
|
||||
// let Some(ast::Located {
|
||||
// node: ast::StmtKind::ClassDef { name, body, .. },
|
||||
// ..
|
||||
// }) = def.1
|
||||
// else {
|
||||
// return None;
|
||||
// };
|
||||
// if *name == class_name {}
|
||||
// None
|
||||
// });
|
||||
// }
|
||||
}
|
||||
ast::StmtKind::Pass { .. }
|
||||
| ast::StmtKind::Assert { .. }
|
||||
|
@ -130,14 +130,14 @@ pub enum TopLevelDef {
|
||||
/// Function instance to symbol mapping
|
||||
///
|
||||
/// * Key: String representation of type variable values, sorted by variable ID in ascending
|
||||
/// order, including type variables associated with the class.
|
||||
/// order, including type variables associated with the class.
|
||||
/// * Value: Function symbol name.
|
||||
instance_to_symbol: HashMap<String, String>,
|
||||
/// Function instances to annotated AST mapping
|
||||
///
|
||||
/// * Key: String representation of type variable values, sorted by variable ID in ascending
|
||||
/// order, including type variables associated with the class. Excluding rigid type
|
||||
/// variables.
|
||||
/// order, including type variables associated with the class. Excluding rigid type
|
||||
/// variables.
|
||||
///
|
||||
/// Rigid type variables that would be substituted when the function is instantiated.
|
||||
instance_to_stmt: HashMap<String, FunInstance>,
|
||||
|
@ -10,9 +10,9 @@ use itertools::Itertools;
|
||||
/// Creates a `ndarray` [`Type`] with the given type arguments.
|
||||
///
|
||||
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
/// specialized.
|
||||
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
/// specialized.
|
||||
pub fn make_ndarray_ty(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
@ -25,9 +25,9 @@ pub fn make_ndarray_ty(
|
||||
/// Substitutes type variables in `ndarray`.
|
||||
///
|
||||
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
/// specialized.
|
||||
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
||||
/// specialized.
|
||||
/// specialized.
|
||||
pub fn subst_ndarray_tvars(
|
||||
unifier: &mut Unifier,
|
||||
ndarray: Type,
|
||||
|
@ -64,9 +64,9 @@ impl TypeAnnotation {
|
||||
/// Parses an AST expression `expr` into a [`TypeAnnotation`].
|
||||
///
|
||||
/// * `locked` - A [`HashMap`] containing the IDs of known definitions, mapped to a [`Vec`] of all
|
||||
/// generic variables associated with the definition.
|
||||
/// generic variables associated with the definition.
|
||||
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
|
||||
/// [`None`] when this function is invoked externally.
|
||||
/// [`None`] when this function is invoked externally.
|
||||
pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
||||
resolver: &(dyn SymbolResolver + Send + Sync),
|
||||
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
||||
|
@ -520,23 +520,6 @@ pub fn typeof_binop(
|
||||
}
|
||||
|
||||
Operator::MatMult => {
|
||||
// NOTE: NumPy matmul's LHS and RHS must both be ndarrays. Scalars are not allowed.
|
||||
match (&*unifier.get_ty(lhs), &*unifier.get_ty(rhs)) {
|
||||
(
|
||||
TypeEnum::TObj { obj_id: lhs_obj_id, .. },
|
||||
TypeEnum::TObj { obj_id: rhs_obj_id, .. },
|
||||
) if *lhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap()
|
||||
&& *rhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap() =>
|
||||
{
|
||||
// LHS and RHS have valid types
|
||||
}
|
||||
_ => {
|
||||
let lhs_str = unifier.stringify(lhs);
|
||||
let rhs_str = unifier.stringify(rhs);
|
||||
return Err(format!("ndarray.__matmul__ only accepts ndarray operands, but left operand has type {lhs_str}, and right operand has type {rhs_str}"));
|
||||
}
|
||||
}
|
||||
|
||||
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
|
||||
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
|
||||
TypeEnum::TLiteral { values, .. } => {
|
||||
@ -697,7 +680,6 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||
bool: bool_t,
|
||||
uint32: uint32_t,
|
||||
uint64: uint64_t,
|
||||
str: str_t,
|
||||
list: list_t,
|
||||
ndarray: ndarray_t,
|
||||
..
|
||||
@ -743,9 +725,6 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
||||
impl_sign(unifier, store, bool_t, Some(int32_t));
|
||||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
||||
|
||||
/* str ========= */
|
||||
impl_cmpop(unifier, store, str_t, &[str_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t));
|
||||
|
||||
/* list ======== */
|
||||
impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]);
|
||||
impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]);
|
||||
|
@ -103,7 +103,6 @@ pub struct Inferencer<'a> {
|
||||
}
|
||||
|
||||
type InferenceError = HashSet<String>;
|
||||
type OverrideResult = Result<Option<ast::Expr<Option<Type>>>, InferenceError>;
|
||||
|
||||
struct NaiveFolder();
|
||||
impl Fold<()> for NaiveFolder {
|
||||
@ -1676,10 +1675,15 @@ impl<'a> Inferencer<'a> {
|
||||
|
||||
/// Checks whether a class method is calling parent function
|
||||
/// Returns [`None`] if its not a call to parent method, otherwise
|
||||
/// returns a new `func` with class name replaced by `self` and method resolved to its `DefinitionID`
|
||||
/// returns a new `func` with class name replaced by `self` and class name store as `ExprKind::Constant`
|
||||
///
|
||||
/// e.g. A.f1(self, ...) returns Some(self.{DefintionID(f1)})
|
||||
fn check_overriding(&mut self, func: &ast::Expr<()>, args: &[ast::Expr<()>]) -> OverrideResult {
|
||||
/// e.g. A.f1(self, ...) returns Some(self.f1, Some(ExprKind::Constant(A))
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn check_overriding(
|
||||
&mut self,
|
||||
func: &ast::Expr<()>,
|
||||
args: &[ast::Expr<()>],
|
||||
) -> Result<Option<(ast::Expr<()>, Option<ast::Expr<Option<Type>>>)>, InferenceError> {
|
||||
// `self` must be first argument for call to parent method
|
||||
if let Some(Located { node: ExprKind::Name { id, .. }, .. }) = &args.first() {
|
||||
if *id != "self".into() {
|
||||
@ -1713,11 +1717,11 @@ impl<'a> Inferencer<'a> {
|
||||
};
|
||||
// Class names are stored as `__module__.class`
|
||||
let name = name.to_string();
|
||||
let (_, name) = name.rsplit_once('.').unwrap();
|
||||
let (_, name) = name.split_once('.').unwrap();
|
||||
if name == class_name.to_string() {
|
||||
return methods.iter().find_map(|f| {
|
||||
if f.0 == *method_name {
|
||||
return Some(*f);
|
||||
return Some(ast::Constant::Int(f.2 .0.try_into().unwrap()));
|
||||
}
|
||||
None
|
||||
});
|
||||
@ -1736,19 +1740,19 @@ impl<'a> Inferencer<'a> {
|
||||
let mut new_value = value.clone();
|
||||
new_value.node = ExprKind::Name { id: "self".into(), ctx: *class_ctx };
|
||||
new_func.node =
|
||||
ExprKind::Attribute { value: new_value.clone(), attr: *method_name, ctx: *ctx };
|
||||
ExprKind::Attribute { value: new_value, attr: *method_name, ctx: *ctx };
|
||||
|
||||
let mut new_func = self.fold_expr(new_func)?;
|
||||
let dummy_arg = self.fold_expr(Located {
|
||||
location: *location,
|
||||
custom: (),
|
||||
node: ExprKind::Constant::<()> { value: r, kind: None },
|
||||
})?;
|
||||
|
||||
let ExprKind::Attribute { value, .. } = new_func.node else { unreachable!() };
|
||||
new_func.node =
|
||||
ExprKind::Attribute { value, attr: r.2 .0.to_string().into(), ctx: *ctx };
|
||||
new_func.custom = Some(r.1);
|
||||
|
||||
Ok(Some(new_func))
|
||||
// args.remove (dummy_arg);
|
||||
Ok(Some((new_func, Some(dummy_arg))))
|
||||
}
|
||||
None => report_error(
|
||||
format!("Ancestor method [{class_name}.{method_name}] should be defined with same decorator as its overridden version").as_str(),
|
||||
format!("Method {class_name}.{method_name} not found in ancestor list").as_str(),
|
||||
*location,
|
||||
),
|
||||
}
|
||||
@ -1767,19 +1771,20 @@ impl<'a> Inferencer<'a> {
|
||||
return Ok(spec_call_func);
|
||||
}
|
||||
|
||||
let mut zelf = None;
|
||||
|
||||
// Check for call to parent method
|
||||
let override_res = self.check_overriding(&func, &args)?;
|
||||
let is_override = override_res.is_some();
|
||||
let func = if is_override { override_res.unwrap() } else { self.fold_expr(func)? };
|
||||
let func = Box::new(func);
|
||||
let (func, dummy_var) = override_res.unwrap_or((func, None));
|
||||
|
||||
let func = Box::new(self.fold_expr(func)?);
|
||||
let mut args =
|
||||
args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// TODO: Handle passing of self to functions to allow runtime lookup of functions to be called
|
||||
// Currently removing `self` and using compile time function definitions
|
||||
// Remove self from arguments
|
||||
if is_override {
|
||||
args.remove(0);
|
||||
zelf = Some(args.remove(0));
|
||||
}
|
||||
let keywords = keywords
|
||||
.into_iter()
|
||||
@ -1802,6 +1807,13 @@ impl<'a> Inferencer<'a> {
|
||||
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
|
||||
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
|
||||
})?;
|
||||
|
||||
// Add `class_name` and `self` to arguments for `gen_expr` to generate call to parent method
|
||||
if let Some(mut arg) = zelf {
|
||||
arg.node = ExprKind::Name { id: "self".into(), ctx: ExprContext::Load };
|
||||
args.push(dummy_var.unwrap());
|
||||
args.push(arg);
|
||||
}
|
||||
return Ok(Located {
|
||||
location,
|
||||
custom: Some(sign.ret),
|
||||
|
@ -1,11 +1,3 @@
|
||||
use super::magic_methods::{Binop, HasOpInfo};
|
||||
use super::type_error::{TypeError, TypeErrorKind};
|
||||
use super::unification_table::{UnificationKey, UnificationTable};
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
|
||||
use crate::typecheck::magic_methods::OpInfo;
|
||||
use crate::typecheck::type_inferencer::PrimitiveStore;
|
||||
use indexmap::IndexMap;
|
||||
use itertools::{repeat_n, Itertools};
|
||||
use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
|
||||
@ -17,6 +9,15 @@ use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::{borrow::Cow, collections::HashSet};
|
||||
|
||||
use super::magic_methods::Binop;
|
||||
use super::type_error::{TypeError, TypeErrorKind};
|
||||
use super::unification_table::{UnificationKey, UnificationTable};
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
|
||||
use crate::typecheck::magic_methods::OpInfo;
|
||||
use crate::typecheck::type_inferencer::PrimitiveStore;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
@ -1007,18 +1008,8 @@ impl Unifier {
|
||||
self.unify_impl(v.ty, ty[ind as usize], false)
|
||||
.map_err(|e| e.at(v.loc))?;
|
||||
}
|
||||
RecordKey::Str(s) => {
|
||||
let tuple_fns = [
|
||||
Cmpop::Eq.op_info().method_name,
|
||||
Cmpop::NotEq.op_info().method_name,
|
||||
];
|
||||
|
||||
if !tuple_fns.into_iter().any(|op| s.to_string() == op) {
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::NoSuchField(*k, b),
|
||||
v.loc,
|
||||
));
|
||||
}
|
||||
RecordKey::Str(_) => {
|
||||
return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -238,7 +238,7 @@ impl<'a> EH_Frame<'a> {
|
||||
/// From the [specification](https://refspecs.linuxfoundation.org/LSB_5.0.0/LSB-Core-generic/LSB-Core-generic/ehframechpt.html):
|
||||
///
|
||||
/// > Each CFI record contains a Common Information Entry (CIE) record followed by 1 or more Frame
|
||||
/// > Description Entry (FDE) records.
|
||||
/// Description Entry (FDE) records.
|
||||
pub struct CFI_Record<'a> {
|
||||
// It refers to the augmentation data that corresponds to 'R' in the augmentation string
|
||||
fde_pointer_encoding: u8,
|
||||
|
@ -7,11 +7,11 @@
|
||||
#include <string.h>
|
||||
|
||||
double dbl_nan(void) {
|
||||
return NAN;
|
||||
return NAN;
|
||||
}
|
||||
|
||||
double dbl_inf(void) {
|
||||
return INFINITY;
|
||||
return INFINITY;
|
||||
}
|
||||
|
||||
void output_bool(bool x) {
|
||||
@ -19,19 +19,19 @@ void output_bool(bool x) {
|
||||
}
|
||||
|
||||
void output_int32(int32_t x) {
|
||||
printf("%" PRId32 "\n", x);
|
||||
printf("%"PRId32"\n", x);
|
||||
}
|
||||
|
||||
void output_int64(int64_t x) {
|
||||
printf("%" PRId64 "\n", x);
|
||||
printf("%"PRId64"\n", x);
|
||||
}
|
||||
|
||||
void output_uint32(uint32_t x) {
|
||||
printf("%" PRIu32 "\n", x);
|
||||
printf("%"PRIu32"\n", x);
|
||||
}
|
||||
|
||||
void output_uint64(uint64_t x) {
|
||||
printf("%" PRIu64 "\n", x);
|
||||
printf("%"PRIu64"\n", x);
|
||||
}
|
||||
|
||||
void output_float64(double x) {
|
||||
@ -52,7 +52,7 @@ void output_range(int32_t range[3]) {
|
||||
}
|
||||
|
||||
void output_asciiart(int32_t x) {
|
||||
static const char* chars = " .,-:;i+hHM$*#@ ";
|
||||
static const char *chars = " .,-:;i+hHM$*#@ ";
|
||||
if (x < 0) {
|
||||
putchar('\n');
|
||||
} else {
|
||||
@ -61,12 +61,12 @@ void output_asciiart(int32_t x) {
|
||||
}
|
||||
|
||||
struct cslice {
|
||||
void* data;
|
||||
void *data;
|
||||
size_t len;
|
||||
};
|
||||
|
||||
void output_int32_list(struct cslice* slice) {
|
||||
const int32_t* data = (int32_t*)slice->data;
|
||||
void output_int32_list(struct cslice *slice) {
|
||||
const int32_t *data = (int32_t *) slice->data;
|
||||
|
||||
putchar('[');
|
||||
for (size_t i = 0; i < slice->len; ++i) {
|
||||
@ -80,23 +80,23 @@ void output_int32_list(struct cslice* slice) {
|
||||
putchar('\n');
|
||||
}
|
||||
|
||||
void output_str(struct cslice* slice) {
|
||||
const char* data = (const char*)slice->data;
|
||||
void output_str(struct cslice *slice) {
|
||||
const char *data = (const char *) slice->data;
|
||||
|
||||
for (size_t i = 0; i < slice->len; ++i) {
|
||||
putchar(data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void output_strln(struct cslice* slice) {
|
||||
void output_strln(struct cslice *slice) {
|
||||
output_str(slice);
|
||||
putchar('\n');
|
||||
}
|
||||
|
||||
uint64_t dbg_stack_address(__attribute__((unused)) struct cslice* slice) {
|
||||
uint64_t dbg_stack_address(__attribute__((unused)) struct cslice *slice) {
|
||||
int i;
|
||||
void* ptr = (void*)&i;
|
||||
return (uintptr_t)ptr;
|
||||
void *ptr = (void *) &i;
|
||||
return (uintptr_t) ptr;
|
||||
}
|
||||
|
||||
uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t context) {
|
||||
@ -119,12 +119,11 @@ struct Exception {
|
||||
|
||||
uint32_t __nac3_raise(struct Exception* e) {
|
||||
printf("__nac3_raise called. Exception details:\n");
|
||||
printf(" ID: %" PRIu32 "\n", e->id);
|
||||
printf(" Location: %*s:%" PRIu32 ":%" PRIu32 "\n", (int)e->file.len, (const char*)e->file.data, e->line,
|
||||
e->column);
|
||||
printf(" Function: %*s\n", (int)e->function.len, (const char*)e->function.data);
|
||||
printf(" Message: \"%*s\"\n", (int)e->message.len, (const char*)e->message.data);
|
||||
printf(" Params: {0}=%" PRId64 ", {1}=%" PRId64 ", {2}=%" PRId64 "\n", e->param[0], e->param[1], e->param[2]);
|
||||
printf(" ID: %"PRIu32"\n", e->id);
|
||||
printf(" Location: %*s:%"PRIu32":%"PRIu32"\n" , (int) e->file.len, (const char*) e->file.data, e->line, e->column);
|
||||
printf(" Function: %*s\n" , (int) e->function.len, (const char*) e->function.data);
|
||||
printf(" Message: \"%*s\"\n" , (int) e->message.len, (const char*) e->message.data);
|
||||
printf(" Params: {0}=%"PRId64", {1}=%"PRId64", {2}=%"PRId64"\n", e->param[0], e->param[1], e->param[2]);
|
||||
exit(101);
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
@ -9,7 +9,6 @@ def output_bool(x: bool):
|
||||
def example1():
|
||||
x, *ys, z = (1, 2, 3, 4, 5)
|
||||
output_int32(x)
|
||||
output_int32(len(ys))
|
||||
output_int32(ys[0])
|
||||
output_int32(ys[1])
|
||||
output_int32(ys[2])
|
||||
@ -19,14 +18,12 @@ def example2():
|
||||
x, y, *zs = (1, 2, 3, 4, 5)
|
||||
output_int32(x)
|
||||
output_int32(y)
|
||||
output_int32(len(zs))
|
||||
output_int32(zs[0])
|
||||
output_int32(zs[1])
|
||||
output_int32(zs[2])
|
||||
|
||||
def example3():
|
||||
*xs, y, z = (1, 2, 3, 4, 5)
|
||||
output_int32(len(xs))
|
||||
output_int32(xs[0])
|
||||
output_int32(xs[1])
|
||||
output_int32(xs[2])
|
||||
@ -34,12 +31,6 @@ def example3():
|
||||
output_int32(z)
|
||||
|
||||
def example4():
|
||||
*xs, y, z = (4, 5)
|
||||
output_int32(len(xs))
|
||||
output_int32(y)
|
||||
output_int32(z)
|
||||
|
||||
def example5():
|
||||
# Example from: https://docs.python.org/3/reference/simple_stmts.html#assignment-statements
|
||||
x = [0, 1]
|
||||
i = 0
|
||||
@ -53,7 +44,7 @@ class A:
|
||||
def __init__(self):
|
||||
self.value = 1000
|
||||
|
||||
def example6():
|
||||
def example5():
|
||||
ws = [88, 7, 8]
|
||||
a = A()
|
||||
x, [y, *ys, a.value], ws[0], (ws[0],) = 1, (2, False, 4, 5), 99, (6,)
|
||||
@ -72,5 +63,4 @@ def run() -> int32:
|
||||
example3()
|
||||
example4()
|
||||
example5()
|
||||
example6()
|
||||
return 0
|
||||
|
@ -6,62 +6,57 @@ def output_int32(x: int32):
|
||||
|
||||
class A:
|
||||
a: int32
|
||||
|
||||
def __init__(self, a: int32):
|
||||
self.a = a
|
||||
|
||||
def output_all_fields(self):
|
||||
output_int32(self.a)
|
||||
def __init__(self, param_a: int32):
|
||||
self.a = param_a
|
||||
|
||||
def set_a(self, a: int32):
|
||||
self.a = a
|
||||
def f1(self):
|
||||
output_int32(12)
|
||||
|
||||
def f2(self):
|
||||
output_int32(124)
|
||||
|
||||
class B(A):
|
||||
b: int32
|
||||
|
||||
def __init__(self, b: int32):
|
||||
A.__init__(self, b + 1)
|
||||
self.set_b(b)
|
||||
|
||||
def output_parent_fields(self):
|
||||
A.output_all_fields(self)
|
||||
|
||||
def output_all_fields(self):
|
||||
A.output_all_fields(self)
|
||||
output_int32(self.b)
|
||||
def __init__(self, param_a: int32, param_b: int32):
|
||||
self.a = param_a
|
||||
self.b = param_b
|
||||
|
||||
def set_b(self, b: int32):
|
||||
self.b = b
|
||||
def f3(self):
|
||||
output_int32(20)
|
||||
|
||||
def f1(self):
|
||||
output_int32(15)
|
||||
|
||||
def f2(self):
|
||||
self.b = 12
|
||||
A.f1(self)
|
||||
|
||||
class C(B):
|
||||
c: int32
|
||||
|
||||
def __init__(self, c: int32):
|
||||
B.__init__(self, c + 1)
|
||||
self.c = c
|
||||
def __init__(self, a: int32, b: int32):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
def output_parent_fields(self):
|
||||
B.output_all_fields(self)
|
||||
|
||||
def output_all_fields(self):
|
||||
B.output_all_fields(self)
|
||||
output_int32(self.c)
|
||||
def f1(self):
|
||||
output_int32(17)
|
||||
|
||||
def set_c(self, c: int32):
|
||||
self.c = c
|
||||
def f3(self):
|
||||
self.a = 2
|
||||
A.f2(self)
|
||||
|
||||
def f4(self):
|
||||
A.f1(self)
|
||||
B.f2(self)
|
||||
|
||||
def run() -> int32:
|
||||
ccc = C(10)
|
||||
ccc.output_all_fields()
|
||||
ccc.set_a(1)
|
||||
ccc.set_b(2)
|
||||
ccc.set_c(3)
|
||||
ccc.output_all_fields()
|
||||
c = C(1, 2)
|
||||
c.f3()
|
||||
c.f4()
|
||||
|
||||
bbb = B(10)
|
||||
bbb.set_a(9)
|
||||
bbb.set_b(8)
|
||||
bbb.output_all_fields()
|
||||
ccc.output_all_fields()
|
||||
a = A(1)
|
||||
|
||||
output_int32(c.a)
|
||||
output_int32(c.b)
|
||||
|
||||
return 0
|
||||
|
@ -1669,7 +1669,6 @@ def run() -> int32:
|
||||
|
||||
test_ndarray_round()
|
||||
test_ndarray_floor()
|
||||
test_ndarray_ceil()
|
||||
test_ndarray_min()
|
||||
test_ndarray_minimum()
|
||||
test_ndarray_minimum_broadcast()
|
||||
|
@ -1,30 +0,0 @@
|
||||
@extern
|
||||
def output_bool(x: bool):
|
||||
...
|
||||
|
||||
|
||||
def str_eq():
|
||||
output_bool("" == "")
|
||||
output_bool("a" == "")
|
||||
output_bool("a" == "b")
|
||||
output_bool("b" == "a")
|
||||
output_bool("a" == "a")
|
||||
output_bool("test string" == "test string")
|
||||
output_bool("test string1" == "test string2")
|
||||
|
||||
|
||||
def str_ne():
|
||||
output_bool("" != "")
|
||||
output_bool("a" != "")
|
||||
output_bool("a" != "b")
|
||||
output_bool("b" != "a")
|
||||
output_bool("a" != "a")
|
||||
output_bool("test string" != "test string")
|
||||
output_bool("test string1" != "test string2")
|
||||
|
||||
|
||||
def run() -> int32:
|
||||
str_eq()
|
||||
str_ne()
|
||||
|
||||
return 0
|
@ -1,7 +1,3 @@
|
||||
@extern
|
||||
def output_bool(b: bool):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int32_list(x: list[int32]):
|
||||
...
|
||||
@ -17,41 +13,6 @@ class A:
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
|
||||
def test_tuple_eq():
|
||||
# 0-len
|
||||
output_bool(() == ())
|
||||
# 1-len
|
||||
output_bool((1,) == ())
|
||||
output_bool(() == (1,))
|
||||
output_bool((1,) == (1,))
|
||||
output_bool((1,) == (2,))
|
||||
# # 2-len
|
||||
output_bool((1, 2) == ())
|
||||
output_bool(() == (1, 2))
|
||||
output_bool((1,) == (1, 2))
|
||||
output_bool((1, 2) == (1,))
|
||||
output_bool((2, 2) == (1, 2))
|
||||
output_bool((1, 2) == (2, 2))
|
||||
|
||||
|
||||
def test_tuple_ne():
|
||||
# 0-len
|
||||
output_bool(() != ())
|
||||
# 1-len
|
||||
output_bool((1,) != ())
|
||||
output_bool(() != (1,))
|
||||
output_bool((1,) != (1,))
|
||||
output_bool((1,) != (2,))
|
||||
# 2-len
|
||||
output_bool((1, 2) != ())
|
||||
output_bool(() != (1, 2))
|
||||
output_bool((1,) != (1, 2))
|
||||
output_bool((1, 2) != (1,))
|
||||
output_bool((2, 2) != (1, 2))
|
||||
output_bool((1, 2) != (2, 2))
|
||||
|
||||
|
||||
def run() -> int32:
|
||||
data = [0, 1, 2, 3]
|
||||
|
||||
@ -72,7 +33,4 @@ def run() -> int32:
|
||||
output_int32(len((1, 2, 3, 4)))
|
||||
output_int32(len((1, 2, 3, 4, 5)))
|
||||
|
||||
test_tuple_eq()
|
||||
test_tuple_ne()
|
||||
|
||||
return 0
|
@ -15,6 +15,7 @@ use std::{collections::HashMap, sync::Arc};
|
||||
pub struct ResolverInternal {
|
||||
pub id_to_type: Mutex<HashMap<StrRef, Type>>,
|
||||
pub id_to_def: Mutex<HashMap<StrRef, DefinitionId>>,
|
||||
pub class_names: Mutex<HashMap<StrRef, Type>>,
|
||||
pub module_globals: Mutex<HashMap<StrRef, SymbolValue>>,
|
||||
pub str_store: Mutex<HashMap<String, i32>>,
|
||||
}
|
||||
|
@ -306,6 +306,7 @@ fn main() {
|
||||
let internal_resolver: Arc<ResolverInternal> = ResolverInternal {
|
||||
id_to_type: builtins_ty.into(),
|
||||
id_to_def: builtins_def.into(),
|
||||
class_names: Mutex::default(),
|
||||
module_globals: Mutex::default(),
|
||||
str_store: Mutex::default(),
|
||||
}
|
||||
@ -313,15 +314,6 @@ fn main() {
|
||||
let resolver =
|
||||
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
let context = inkwell::context::Context::create();
|
||||
|
||||
// Process IRRT
|
||||
let irrt = load_irrt(&context, resolver.as_ref());
|
||||
if emit_llvm {
|
||||
irrt.write_bitcode_to_path(Path::new("irrt.bc"));
|
||||
}
|
||||
|
||||
// Process the Python script
|
||||
let parser_result = parser::parse_program(&program, file_name.into()).unwrap();
|
||||
|
||||
for stmt in parser_result {
|
||||
@ -426,8 +418,8 @@ fn main() {
|
||||
registry.add_task(task);
|
||||
registry.wait_tasks_complete(handles);
|
||||
|
||||
// Link all modules together into `main`
|
||||
let buffers = membuffers.lock();
|
||||
let context = inkwell::context::Context::create();
|
||||
let main = context
|
||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
|
||||
.unwrap();
|
||||
@ -447,9 +439,12 @@ fn main() {
|
||||
main.link_in_module(other).unwrap();
|
||||
}
|
||||
|
||||
let irrt = load_irrt(&context);
|
||||
if emit_llvm {
|
||||
irrt.write_bitcode_to_path(Path::new("irrt.bc"));
|
||||
}
|
||||
main.link_in_module(irrt).unwrap();
|
||||
|
||||
// Private all functions except "run"
|
||||
let mut function_iter = main.get_first_function();
|
||||
while let Some(func) = function_iter {
|
||||
if func.count_basic_blocks() > 0 && func.get_name().to_str().unwrap() != "run" {
|
||||
@ -458,7 +453,6 @@ fn main() {
|
||||
function_iter = func.get_next_function();
|
||||
}
|
||||
|
||||
// Optimize `main`
|
||||
let target_machine = llvm_options
|
||||
.target
|
||||
.create_target_machine(llvm_options.opt_level)
|
||||
@ -472,7 +466,6 @@ fn main() {
|
||||
panic!("Failed to run optimization for module `main`: {}", err.to_string());
|
||||
}
|
||||
|
||||
// Write output
|
||||
target_machine
|
||||
.write_to_file(&main, FileType::Object, Path::new("module.o"))
|
||||
.expect("couldn't write module to file");
|
||||
|
Loading…
Reference in New Issue
Block a user