forked from M-Labs/nac3
Compare commits
10 Commits
59cad5bfe1
...
219af79017
Author | SHA1 | Date | |
---|---|---|---|
219af79017 | |||
987d17b5e8 | |||
7204513a9f | |||
9c33c4209c | |||
122983f11c | |||
71c3a65a31 | |||
8c540d1033 | |||
0cc60a3d33 | |||
a59c26aa99 | |||
02d93b11d1 |
62
Cargo.lock
generated
62
Cargo.lock
generated
@ -117,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.1.13"
|
||||
version = "1.1.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72db2f7947ecee9b03b510377e8bb9077afa27176fdbff55c51027e976fdcc48"
|
||||
checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6"
|
||||
dependencies = [
|
||||
"shlex",
|
||||
]
|
||||
@ -161,7 +161,7 @@ dependencies = [
|
||||
"heck 0.5.0",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.75",
|
||||
"syn 2.0.76",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -310,9 +310,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.1.0"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a"
|
||||
checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
|
||||
|
||||
[[package]]
|
||||
name = "fixedbitset"
|
||||
@ -424,7 +424,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.75",
|
||||
"syn 2.0.76",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -510,9 +510,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.157"
|
||||
version = "0.2.158"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "374af5f94e54fa97cf75e945cce8a6b201e88a1a07e688b47dfd2a59c66dbd86"
|
||||
checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
|
||||
|
||||
[[package]]
|
||||
name = "libloading"
|
||||
@ -752,7 +752,7 @@ dependencies = [
|
||||
"phf_shared 0.11.2",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.75",
|
||||
"syn 2.0.76",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -856,7 +856,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-macros-backend",
|
||||
"quote",
|
||||
"syn 2.0.75",
|
||||
"syn 2.0.76",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -869,14 +869,14 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"pyo3-build-config",
|
||||
"quote",
|
||||
"syn 2.0.75",
|
||||
"syn 2.0.76",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.36"
|
||||
version = "1.0.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
|
||||
checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
@ -942,9 +942,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "redox_users"
|
||||
version = "0.4.5"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891"
|
||||
checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"libredox",
|
||||
@ -989,9 +989,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.34"
|
||||
version = "0.38.35"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f"
|
||||
checksum = "a85d50532239da68e9addb745ba38ff4612a242c1c7ceea689c4bc7c2f43c36f"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"errno",
|
||||
@ -1035,29 +1035,29 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.208"
|
||||
version = "1.0.209"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cff085d2cb684faa248efb494c39b68e522822ac0de72ccf08109abde717cfb2"
|
||||
checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.208"
|
||||
version = "1.0.209"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24008e81ff7613ed8e5ba0cfaf24e2c2f1e5b8a0495711e44fcd4882fca62bcf"
|
||||
checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.75",
|
||||
"syn 2.0.76",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.125"
|
||||
version = "1.0.127"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed"
|
||||
checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
@ -1147,7 +1147,7 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"rustversion",
|
||||
"syn 2.0.75",
|
||||
"syn 2.0.76",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1163,9 +1163,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.75"
|
||||
version = "2.0.76"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6af063034fc1935ede7be0122941bafa9bacb949334d090b77ca98b5817c7d9"
|
||||
checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@ -1232,7 +1232,7 @@ checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.75",
|
||||
"syn 2.0.76",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1310,9 +1310,9 @@ checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.4"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
|
||||
checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a"
|
||||
|
||||
[[package]]
|
||||
name = "unicode_names2"
|
||||
@ -1510,5 +1510,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.75",
|
||||
"syn 2.0.76",
|
||||
]
|
||||
|
6
flake.lock
generated
6
flake.lock
generated
@ -2,11 +2,11 @@
|
||||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1723637854,
|
||||
"narHash": "sha256-med8+5DSWa2UnOqtdICndjDAEjxr5D7zaIiK4pn0Q7c=",
|
||||
"lastModified": 1724819573,
|
||||
"narHash": "sha256-GnR7/ibgIH1vhoy8cYdmXE6iyZqKqFxQSVkFgosBh6w=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "c3aa7b8938b17aebd2deecf7be0636000d62a2b9",
|
||||
"rev": "71e91c409d1e654808b2621f28a327acfdad8dc2",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
@ -112,10 +112,14 @@ def extern(function):
|
||||
register_function(function)
|
||||
return function
|
||||
|
||||
def rpc(function):
|
||||
def rpc(function, flags={}):
|
||||
"""Decorates a function declaration defined by the core device runtime."""
|
||||
register_function(function)
|
||||
return function
|
||||
@wraps(function)
|
||||
def wrapped(function)
|
||||
return function
|
||||
wrapped.__artiq_rpc_flags__ = flags
|
||||
return wrapped
|
||||
|
||||
def kernel(function_or_method):
|
||||
"""Decorates a function or method to be executed on the core device."""
|
||||
|
@ -2,7 +2,7 @@ use nac3core::{
|
||||
codegen::{
|
||||
classes::{
|
||||
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
|
||||
NDArrayValue, RangeValue, UntypedArrayLikeAccessor,
|
||||
NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor,
|
||||
},
|
||||
expr::{destructure_range, gen_call},
|
||||
irrt::call_ndarray_calc_size,
|
||||
@ -22,7 +22,7 @@ use inkwell::{
|
||||
module::Linkage,
|
||||
types::{BasicType, IntType},
|
||||
values::{BasicValueEnum, PointerValue, StructValue},
|
||||
AddressSpace, IntPredicate,
|
||||
AddressSpace, IntPredicate, OptimizationLevel,
|
||||
};
|
||||
|
||||
use pyo3::{
|
||||
@ -32,6 +32,7 @@ use pyo3::{
|
||||
|
||||
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
||||
|
||||
use inkwell::values::IntValue;
|
||||
use itertools::Itertools;
|
||||
use std::{
|
||||
collections::{hash_map::DefaultHasher, HashMap},
|
||||
@ -460,8 +461,8 @@ fn format_rpc_arg<'ctx>(
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
|
||||
let llvm_arg_ty =
|
||||
NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty));
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
|
||||
let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
|
||||
|
||||
let llvm_usize_sizeof = ctx
|
||||
@ -471,7 +472,7 @@ fn format_rpc_arg<'ctx>(
|
||||
let llvm_pdata_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(
|
||||
llvm_arg_ty.element_type().ptr_type(AddressSpace::default()).size_of(),
|
||||
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
|
||||
llvm_usize,
|
||||
"",
|
||||
)
|
||||
@ -486,13 +487,10 @@ fn format_rpc_arg<'ctx>(
|
||||
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
|
||||
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
|
||||
|
||||
let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap();
|
||||
ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap();
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
buffer.base_ptr(ctx, generator),
|
||||
ppdata,
|
||||
llvm_arg.ptr_to_data(ctx),
|
||||
llvm_pdata_sizeof,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
@ -528,6 +526,298 @@ fn format_rpc_arg<'ctx>(
|
||||
arg_slot
|
||||
}
|
||||
|
||||
/// Formats an RPC return value to conform to the expected format required by NAC3.
|
||||
fn format_rpc_ret<'ctx>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ret_ty: Type,
|
||||
) -> Option<BasicValueEnum<'ctx>> {
|
||||
// -- receive value:
|
||||
// T result = {
|
||||
// void *ret_ptr = alloca(sizeof(T));
|
||||
// void *ptr = ret_ptr;
|
||||
// loop: int size = rpc_recv(ptr);
|
||||
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
|
||||
// if(size) { ptr = alloca(size); goto loop; }
|
||||
// else *(T*)ret_ptr
|
||||
// }
|
||||
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
|
||||
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
|
||||
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
|
||||
});
|
||||
|
||||
if ctx.unifier.unioned(ret_ty, ctx.primitives.none) {
|
||||
ctx.build_call_or_invoke(rpc_recv, &[llvm_pi8.const_null().into()], "rpc_recv");
|
||||
return None;
|
||||
}
|
||||
|
||||
let prehead_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let current_function = prehead_bb.get_parent().unwrap();
|
||||
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
|
||||
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
|
||||
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
|
||||
|
||||
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty);
|
||||
|
||||
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
// Round `val` up to its modulo `power_of_two`
|
||||
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
val: IntValue<'ctx>,
|
||||
power_of_two: IntValue<'ctx>| {
|
||||
debug_assert_eq!(
|
||||
val.get_type().get_bit_width(),
|
||||
power_of_two.get_type().get_bit_width()
|
||||
);
|
||||
|
||||
let llvm_val_t = val.get_type();
|
||||
|
||||
let max_rem = ctx
|
||||
.builder
|
||||
.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "")
|
||||
.unwrap();
|
||||
ctx.builder
|
||||
.build_and(
|
||||
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
|
||||
ctx.builder.build_not(max_rem, "").unwrap(),
|
||||
"",
|
||||
)
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
// Setup types
|
||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
|
||||
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
|
||||
|
||||
// Allocate the resulting ndarray
|
||||
// A condition after format_rpc_ret ensures this will not be popped this off.
|
||||
let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result"));
|
||||
|
||||
// Setup ndims
|
||||
let ndims =
|
||||
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) {
|
||||
assert_eq!(values.len(), 1);
|
||||
|
||||
u64::try_from(values[0].clone()).unwrap()
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
// Set `ndarray.ndims`
|
||||
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
|
||||
// Allocate `ndarray.shape` [size_t; ndims]
|
||||
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx));
|
||||
|
||||
/*
|
||||
ndarray now:
|
||||
- .ndims: initialized
|
||||
- .shape: allocated but uninitialized .shape
|
||||
- .data: uninitialized
|
||||
*/
|
||||
|
||||
let llvm_usize_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
|
||||
.unwrap();
|
||||
let llvm_pdata_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(
|
||||
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
|
||||
llvm_usize,
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
let llvm_elem_sizeof = ctx
|
||||
.builder
|
||||
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
|
||||
.unwrap();
|
||||
|
||||
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
|
||||
// (4 + 4 * ndims) bytes with 8-byte alignment
|
||||
let sizeof_dims =
|
||||
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
|
||||
let unaligned_buffer_size =
|
||||
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
|
||||
let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false));
|
||||
|
||||
let stackptr = call_stacksave(ctx, None);
|
||||
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment
|
||||
let buffer = ctx
|
||||
.builder
|
||||
.build_array_alloca(
|
||||
llvm_i8_8,
|
||||
ctx.builder
|
||||
.build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "")
|
||||
.unwrap(),
|
||||
"rpc.buffer",
|
||||
)
|
||||
.unwrap();
|
||||
let buffer = ctx
|
||||
.builder
|
||||
.build_bitcast(buffer, llvm_pi8, "")
|
||||
.map(BasicValueEnum::into_pointer_value)
|
||||
.unwrap();
|
||||
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
|
||||
|
||||
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
|
||||
//
|
||||
// The returned value is the number of bytes for `ndarray.data`.
|
||||
let ndarray_nbytes = ctx
|
||||
.build_call_or_invoke(
|
||||
rpc_recv,
|
||||
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims].
|
||||
"rpc.size.next",
|
||||
)
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
|
||||
// debug_assert(ndarray_nbytes > 0)
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::UGT,
|
||||
ndarray_nbytes,
|
||||
ndarray_nbytes.get_type().const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap(),
|
||||
"0:AssertionError",
|
||||
"Unexpected RPC termination for ndarray - Expected data buffer next",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
// Copy shape from the buffer to `ndarray.shape`.
|
||||
let pbuffer_dims =
|
||||
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
|
||||
|
||||
call_memcpy_generic(
|
||||
ctx,
|
||||
ndarray.dim_sizes().base_ptr(ctx, generator),
|
||||
pbuffer_dims,
|
||||
sizeof_dims,
|
||||
llvm_i1.const_zero(),
|
||||
);
|
||||
// Restore stack from before allocation of buffer
|
||||
call_stackrestore(ctx, stackptr);
|
||||
|
||||
// Allocate `ndarray.data`.
|
||||
// `ndarray.shape` must be initialized beforehand in this implementation
|
||||
// (for ndarray.create_data() to know how many elements to allocate)
|
||||
let num_elements =
|
||||
call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None));
|
||||
|
||||
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let sizeof_data =
|
||||
ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::UGE,
|
||||
sizeof_data,
|
||||
ndarray_nbytes,
|
||||
"",
|
||||
).unwrap(),
|
||||
"0:AssertionError",
|
||||
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
|
||||
[Some(sizeof_data), Some(ndarray_nbytes), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
ndarray.create_data(ctx, llvm_elem_ty, num_elements);
|
||||
|
||||
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
|
||||
let ndarray_data_i8 =
|
||||
ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
|
||||
|
||||
// NOTE: Currently on `prehead_bb`
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
|
||||
// Inserting into `head_bb`. Do `rpc_recv` for `data` recursively.
|
||||
ctx.builder.position_at_end(head_bb);
|
||||
|
||||
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]);
|
||||
|
||||
let alloc_size = ctx
|
||||
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.unwrap();
|
||||
|
||||
let is_done = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
|
||||
.unwrap();
|
||||
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(alloc_bb);
|
||||
// Align the allocation to sizeof(T)
|
||||
let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof);
|
||||
let alloc_ptr = ctx
|
||||
.builder
|
||||
.build_array_alloca(
|
||||
llvm_elem_ty,
|
||||
ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(),
|
||||
"rpc.alloc",
|
||||
)
|
||||
.unwrap();
|
||||
let alloc_ptr =
|
||||
ctx.builder.build_pointer_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(tail_bb);
|
||||
ndarray.as_base_value().into()
|
||||
}
|
||||
|
||||
_ => {
|
||||
let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap();
|
||||
let slotgen = ctx.builder.build_bitcast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
ctx.builder.position_at_end(head_bb);
|
||||
|
||||
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&slotgen, prehead_bb)]);
|
||||
let alloc_size = ctx
|
||||
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let is_done = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
|
||||
.unwrap();
|
||||
|
||||
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
||||
ctx.builder.position_at_end(alloc_bb);
|
||||
|
||||
let alloc_ptr =
|
||||
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
|
||||
let alloc_ptr =
|
||||
ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(tail_bb);
|
||||
ctx.builder.build_load(slot, "rpc.result").unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
Some(result)
|
||||
}
|
||||
|
||||
fn rpc_codegen_callback_fn<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
@ -541,6 +831,9 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
let ptr_type = int8.ptr_type(AddressSpace::default());
|
||||
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
||||
|
||||
// println!("obj: {:?}", obj);
|
||||
println!("fun: {:?}", fun);
|
||||
|
||||
let service_id = int32.const_int(fun.1 .0 as u64, false);
|
||||
// -- setup rpc tags
|
||||
let mut tag = Vec::new();
|
||||
@ -663,63 +956,14 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
// reclaim stack space used by arguments
|
||||
call_stackrestore(ctx, stackptr);
|
||||
|
||||
// -- receive value:
|
||||
// T result = {
|
||||
// void *ret_ptr = alloca(sizeof(T));
|
||||
// void *ptr = ret_ptr;
|
||||
// loop: int size = rpc_recv(ptr);
|
||||
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
|
||||
// if(size) { ptr = alloca(size); goto loop; }
|
||||
// else *(T*)ret_ptr
|
||||
// }
|
||||
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
|
||||
ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None)
|
||||
});
|
||||
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
||||
|
||||
if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) {
|
||||
ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv");
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let prehead_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let current_function = prehead_bb.get_parent().unwrap();
|
||||
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
|
||||
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
|
||||
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
|
||||
|
||||
let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret);
|
||||
let need_load = !ret_ty.is_pointer_type();
|
||||
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot").unwrap();
|
||||
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr").unwrap();
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
ctx.builder.position_at_end(head_bb);
|
||||
|
||||
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&slotgen, prehead_bb)]);
|
||||
let alloc_size = ctx
|
||||
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let is_done = ctx
|
||||
.builder
|
||||
.build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done")
|
||||
.unwrap();
|
||||
|
||||
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
|
||||
ctx.builder.position_at_end(alloc_bb);
|
||||
|
||||
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc").unwrap();
|
||||
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr").unwrap();
|
||||
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
|
||||
ctx.builder.build_unconditional_branch(head_bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(tail_bb);
|
||||
|
||||
let result = ctx.builder.build_load(slot, "rpc.result").unwrap();
|
||||
if need_load {
|
||||
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
|
||||
// An RPC returning an NDArray would not touch here.
|
||||
call_stackrestore(ctx, stackptr);
|
||||
}
|
||||
Ok(Some(result))
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub fn attributes_writeback(
|
||||
@ -810,6 +1054,7 @@ pub fn attributes_writeback(
|
||||
.collect(),
|
||||
ret: ctx.primitives.none,
|
||||
vars: VarMap::default(),
|
||||
opts: vec![],
|
||||
};
|
||||
let args: Vec<_> =
|
||||
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
|
||||
|
@ -191,14 +191,30 @@ impl Nac3 {
|
||||
})
|
||||
.unwrap()
|
||||
});
|
||||
body.retain(|stmt| {
|
||||
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node {
|
||||
decorator_list.iter().any(|decorator| {
|
||||
body.retain_mut(|stmt| {
|
||||
if let StmtKind::FunctionDef { ref mut decorator_list, .. } = stmt.node {
|
||||
decorator_list.iter_mut().any(|decorator| {
|
||||
if let ExprKind::Name { id, .. } = decorator.node {
|
||||
id.to_string() == "kernel"
|
||||
|| id.to_string() == "portable"
|
||||
|| id.to_string() == "rpc"
|
||||
} else if let ExprKind::Call { func, .. } = &decorator.node {
|
||||
// decorators with flags (e.g. rpc async) have Call for the node;
|
||||
// this is to remove the middle
|
||||
if let ExprKind::Name { id, .. } = func.node {
|
||||
if id.to_string() == "rpc" {
|
||||
println!("found rpc: {:?}", func);
|
||||
println!("decorator node: {:?}", decorator.node);
|
||||
decorator.node = func.clone().node;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
else {
|
||||
false
|
||||
}
|
||||
})
|
||||
@ -318,6 +334,7 @@ impl Nac3 {
|
||||
}],
|
||||
ret: primitives.none,
|
||||
vars: into_var_map([arg_ty]),
|
||||
opts: vec![],
|
||||
},
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
|
||||
gen_core_log(ctx, &obj, fun, &args, generator)?;
|
||||
@ -348,6 +365,7 @@ impl Nac3 {
|
||||
],
|
||||
ret: primitives.none,
|
||||
vars: into_var_map([arg_ty]),
|
||||
opts: vec![],
|
||||
},
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
|
||||
gen_rtio_log(ctx, &obj, fun, &args, generator)?;
|
||||
@ -495,7 +513,7 @@ impl Nac3 {
|
||||
class_name, stmt.location
|
||||
)));
|
||||
}
|
||||
rpc_ids.push((Some((class_obj.clone(), *name)), def_id));
|
||||
rpc_ids.push((Some((class_obj.clone(), *name)), def_id, Some(FuncFlags::Async)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -560,7 +578,7 @@ impl Nac3 {
|
||||
let irrt = load_irrt(&context, resolver.as_ref());
|
||||
|
||||
let fun_signature =
|
||||
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
|
||||
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new(), opts: vec![] };
|
||||
let mut store = ConcreteTypeStore::new();
|
||||
let mut cache = HashMap::new();
|
||||
let signature = store.from_signature(
|
||||
@ -598,7 +616,7 @@ impl Nac3 {
|
||||
{
|
||||
let rpc_codegen = rpc_codegen_callback();
|
||||
let defs = top_level.definitions.read();
|
||||
for (class_data, id) in &rpc_ids {
|
||||
for (class_data, id, flags) in &rpc_ids {
|
||||
let mut def = defs[id.0].write();
|
||||
match &mut *def {
|
||||
TopLevelDef::Function { codegen_callback, .. } => {
|
||||
@ -917,7 +935,7 @@ impl Nac3 {
|
||||
let builtins = vec![
|
||||
(
|
||||
"now_mu".into(),
|
||||
FunSignature { args: vec![], ret: primitive.int64, vars: VarMap::new() },
|
||||
FunSignature { args: vec![], ret: primitive.int64, vars: VarMap::new(), opts: vec![] },
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| {
|
||||
Ok(Some(time_fns.emit_now_mu(ctx)))
|
||||
}))),
|
||||
@ -933,6 +951,7 @@ impl Nac3 {
|
||||
}],
|
||||
ret: primitive.none,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
},
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
@ -953,6 +972,7 @@ impl Nac3 {
|
||||
}],
|
||||
ret: primitive.none,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
},
|
||||
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
|
@ -1250,11 +1250,13 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||
|
||||
/// Returns the element type of this `ndarray` type.
|
||||
#[must_use]
|
||||
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
|
||||
pub fn element_type(&self) -> AnyTypeEnum<'ctx> {
|
||||
self.as_base_type()
|
||||
.get_element_type()
|
||||
.into_struct_type()
|
||||
.get_field_type_at_index(2)
|
||||
.map(BasicTypeEnum::into_pointer_type)
|
||||
.map(PointerType::get_element_type)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
@ -1404,7 +1406,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||
|
||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||
/// on the field.
|
||||
fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
|
||||
|
||||
|
@ -297,6 +297,7 @@ impl ConcreteTypeStore {
|
||||
let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
|
||||
TypeVar { id, ty }
|
||||
})),
|
||||
opts: vec![],
|
||||
}),
|
||||
ConcreteTypeEnum::TLiteral { values, .. } => {
|
||||
TypeEnum::TLiteral { values: values.clone(), loc: None }
|
||||
|
@ -2385,7 +2385,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
})
|
||||
.map(BasicValueEnum::into_int_value)?;
|
||||
|
||||
Ok(ctx.builder.build_not(cmp, "").unwrap())
|
||||
Ok(ctx.builder.build_not(
|
||||
generator.bool_to_i1(ctx, cmp),
|
||||
"",
|
||||
).unwrap())
|
||||
},
|
||||
|_, ctx| {
|
||||
let bb = ctx.builder.get_insert_block().unwrap();
|
||||
|
@ -174,6 +174,14 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
}
|
||||
let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?;
|
||||
|
||||
// Perform i1 <-> i8 conversion as needed
|
||||
let val = if ctx.unifier.unioned(target.custom.unwrap(), ctx.primitives.bool) {
|
||||
generator.bool_to_i8(ctx, val.into_int_value()).into()
|
||||
} else {
|
||||
val
|
||||
};
|
||||
|
||||
ctx.builder.build_store(ptr, val).unwrap();
|
||||
}
|
||||
};
|
||||
@ -1665,6 +1673,23 @@ pub fn gen_return<G: CodeGenerator>(
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Remap boolean return type into i1
|
||||
let value = value.map(|ret_val| {
|
||||
// The "return type" of a sret function is in the first parameter
|
||||
let expected_ty = if ctx.need_sret {
|
||||
func.get_type().get_param_types()[0]
|
||||
} else {
|
||||
func.get_type().get_return_type().unwrap()
|
||||
};
|
||||
|
||||
if matches!(expected_ty, BasicTypeEnum::IntType(ty) if ty.get_bit_width() == 1) {
|
||||
generator.bool_to_i1(ctx, ret_val.into_int_value()).into()
|
||||
} else {
|
||||
ret_val
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(return_target) = ctx.return_target {
|
||||
if let Some(value) = value {
|
||||
ctx.builder.build_store(ctx.return_buffer.unwrap(), value).unwrap();
|
||||
@ -1675,25 +1700,6 @@ pub fn gen_return<G: CodeGenerator>(
|
||||
ctx.builder.build_store(ctx.return_buffer.unwrap(), value.unwrap()).unwrap();
|
||||
ctx.builder.build_return(None).unwrap();
|
||||
} else {
|
||||
// Remap boolean return type into i1
|
||||
let value = value.map(|v| {
|
||||
let expected_ty = func.get_type().get_return_type().unwrap();
|
||||
let ret_val = v.as_basic_value_enum();
|
||||
|
||||
if expected_ty.is_int_type() && ret_val.is_int_value() {
|
||||
let ret_type = expected_ty.into_int_type();
|
||||
let ret_val = ret_val.into_int_value();
|
||||
|
||||
if ret_type.get_bit_width() == 1 && ret_val.get_type().get_bit_width() != 1 {
|
||||
generator.bool_to_i1(ctx, ret_val)
|
||||
} else {
|
||||
ret_val
|
||||
}
|
||||
.into()
|
||||
} else {
|
||||
ret_val
|
||||
}
|
||||
});
|
||||
let value = value.as_ref().map(|v| v as &dyn BasicValue);
|
||||
ctx.builder.build_return(value).unwrap();
|
||||
}
|
||||
|
@ -124,6 +124,7 @@ fn test_primitives() {
|
||||
],
|
||||
ret: primitives.int32,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
};
|
||||
|
||||
let mut store = ConcreteTypeStore::new();
|
||||
@ -273,6 +274,7 @@ fn test_simple_call() {
|
||||
}],
|
||||
ret: primitives.int32,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
};
|
||||
let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone()));
|
||||
let mut store = ConcreteTypeStore::new();
|
||||
|
@ -77,6 +77,7 @@ pub fn get_exn_constructor(
|
||||
args: exn_cons_args,
|
||||
ret: exn_type,
|
||||
vars: VarMap::default(),
|
||||
opts: vec![],
|
||||
}));
|
||||
let fun_def = TopLevelDef::Function {
|
||||
name: format!("{name}.__init__"),
|
||||
@ -137,6 +138,7 @@ fn create_fn_by_codegen(
|
||||
.collect(),
|
||||
ret: ret_ty,
|
||||
vars: var_map.clone(),
|
||||
opts: vec![],
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
instance_to_symbol: HashMap::default(),
|
||||
@ -670,6 +672,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
],
|
||||
ret: range,
|
||||
vars: VarMap::default(),
|
||||
opts: vec![],
|
||||
}))
|
||||
};
|
||||
|
||||
@ -925,6 +928,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}],
|
||||
ret: self.primitives.option,
|
||||
vars: into_var_map([self.option_tvar]),
|
||||
opts: vec![],
|
||||
})),
|
||||
var_id: vec![self.option_tvar.id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
@ -1060,6 +1064,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}],
|
||||
ret: self.num_or_ndarray_ty.ty,
|
||||
vars: self.num_or_ndarray_var_map.clone(),
|
||||
opts: vec![],
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
instance_to_symbol: HashMap::default(),
|
||||
@ -1297,6 +1302,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
],
|
||||
ret: ndarray,
|
||||
vars: into_var_map([tv]),
|
||||
opts: vec![],
|
||||
})),
|
||||
var_id: vec![tv.id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
@ -1356,6 +1362,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
],
|
||||
ret: self.ndarray_float_2d,
|
||||
vars: VarMap::default(),
|
||||
opts: vec![],
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
instance_to_symbol: HashMap::default(),
|
||||
@ -1403,6 +1410,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}],
|
||||
ret: str,
|
||||
vars: VarMap::default(),
|
||||
opts: vec![],
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
instance_to_symbol: HashMap::default(),
|
||||
@ -1479,6 +1487,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}],
|
||||
ret: self.primitives.int32,
|
||||
vars: into_var_map([arg_tvar]),
|
||||
opts: vec![],
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
instance_to_symbol: HashMap::default(),
|
||||
@ -1520,6 +1529,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
],
|
||||
ret: self.num_ty.ty,
|
||||
vars: self.num_var_map.clone(),
|
||||
opts: vec![],
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
instance_to_symbol: HashMap::default(),
|
||||
@ -1607,6 +1617,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
.collect(),
|
||||
ret: ret_ty.ty,
|
||||
vars: into_var_map([x1_ty, x2_ty, ret_ty]),
|
||||
opts: vec![],
|
||||
})),
|
||||
var_id: vec![x1_ty.id, x2_ty.id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
@ -1648,6 +1659,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
}],
|
||||
ret: self.num_or_ndarray_ty.ty,
|
||||
vars: self.num_or_ndarray_var_map.clone(),
|
||||
opts: vec![],
|
||||
})),
|
||||
var_id: Vec::default(),
|
||||
instance_to_symbol: HashMap::default(),
|
||||
@ -1842,6 +1854,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
.collect(),
|
||||
ret: ret_ty.ty,
|
||||
vars: into_var_map([x1_ty, x2_ty, ret_ty]),
|
||||
opts: vec![],
|
||||
})),
|
||||
var_id: vec![ret_ty.id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
|
@ -1122,6 +1122,7 @@ impl TopLevelComposer {
|
||||
args: arg_types,
|
||||
ret: return_ty,
|
||||
vars: function_var_map,
|
||||
opts: vec![],
|
||||
}));
|
||||
unifier.unify(*dummy_ty, function_ty).map_err(|e| {
|
||||
HashSet::from([e.at(Some(function_ast.location)).to_display(unifier).to_string()])
|
||||
@ -1386,6 +1387,7 @@ impl TopLevelComposer {
|
||||
args: arg_types,
|
||||
ret: ret_type,
|
||||
vars: method_var_map,
|
||||
opts: vec![],
|
||||
}));
|
||||
|
||||
// unify now since function type is not in type annotation define
|
||||
@ -1673,7 +1675,7 @@ impl TopLevelComposer {
|
||||
// they may be changed with our use of placeholders
|
||||
for (def, _) in definition_ast_list.iter().skip(self.builtin_num) {
|
||||
if let TopLevelDef::Function { signature, var_id, .. } = &mut *def.write() {
|
||||
if let TypeEnum::TFunc(FunSignature { args, ret, vars }) =
|
||||
if let TypeEnum::TFunc(FunSignature { args, ret, vars, opts }) =
|
||||
unifier.get_ty(*signature).as_ref()
|
||||
{
|
||||
let new_var_ids = vars
|
||||
@ -1692,6 +1694,7 @@ impl TopLevelComposer {
|
||||
.zip(vars.values())
|
||||
.map(|(id, v)| (*id, *v))
|
||||
.collect(),
|
||||
opts: opts.clone(),
|
||||
};
|
||||
unifier
|
||||
.unification_table
|
||||
@ -1760,6 +1763,7 @@ impl TopLevelComposer {
|
||||
],
|
||||
ret: self_type,
|
||||
vars: VarMap::default(),
|
||||
opts: vec![],
|
||||
}));
|
||||
let cons_fun = TopLevelDef::Function {
|
||||
name: format!("{}.{}", class_name, "__init__"),
|
||||
@ -1807,6 +1811,7 @@ impl TopLevelComposer {
|
||||
args: contor_args,
|
||||
ret: self_type,
|
||||
vars: contor_type_vars,
|
||||
opts: vec![],
|
||||
}));
|
||||
unifier.unify(constructor.unwrap(), contor_type).map_err(|e| {
|
||||
HashSet::from([e
|
||||
@ -1894,7 +1899,7 @@ impl TopLevelComposer {
|
||||
} = &mut *function_def
|
||||
{
|
||||
let signature_ty_enum = unifier.get_ty(*signature);
|
||||
let TypeEnum::TFunc(FunSignature { args, ret, vars }) = signature_ty_enum.as_ref()
|
||||
let TypeEnum::TFunc(FunSignature { args, ret, vars, .. }) = signature_ty_enum.as_ref()
|
||||
else {
|
||||
unreachable!("must be typeenum::tfunc")
|
||||
};
|
||||
|
@ -461,11 +461,13 @@ impl TopLevelComposer {
|
||||
args: vec![],
|
||||
ret: bool,
|
||||
vars: into_var_map([option_type_var]),
|
||||
opts: vec![],
|
||||
}));
|
||||
let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![],
|
||||
ret: option_type_var.ty,
|
||||
vars: into_var_map([option_type_var]),
|
||||
opts: vec![],
|
||||
}));
|
||||
let option = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: PrimDef::Option.id(),
|
||||
@ -500,6 +502,7 @@ impl TopLevelComposer {
|
||||
args: vec![],
|
||||
ret: ndarray_copy_fun_ret_ty.ty,
|
||||
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
|
||||
opts: vec![],
|
||||
}));
|
||||
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg {
|
||||
@ -510,6 +513,7 @@ impl TopLevelComposer {
|
||||
}],
|
||||
ret: none,
|
||||
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
|
||||
opts: vec![],
|
||||
}));
|
||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: PrimDef::NDArray.id(),
|
||||
|
@ -199,6 +199,7 @@ pub fn impl_binop(
|
||||
name: "other".into(),
|
||||
is_vararg: false,
|
||||
}],
|
||||
opts: vec![],
|
||||
})),
|
||||
false,
|
||||
)
|
||||
@ -219,6 +220,7 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops:
|
||||
ret: ret_ty,
|
||||
vars: VarMap::new(),
|
||||
args: vec![],
|
||||
opts: vec![],
|
||||
})),
|
||||
false,
|
||||
),
|
||||
@ -264,6 +266,7 @@ pub fn impl_cmpop(
|
||||
name: "other".into(),
|
||||
is_vararg: false,
|
||||
}],
|
||||
opts: vec![],
|
||||
})),
|
||||
false,
|
||||
),
|
||||
|
@ -464,12 +464,14 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||
|var| var.custom.unwrap(),
|
||||
),
|
||||
vars: VarMap::default(),
|
||||
opts: vec![],
|
||||
});
|
||||
let enter = self.unifier.add_ty(enter);
|
||||
let exit = TypeEnum::TFunc(FunSignature {
|
||||
args: vec![],
|
||||
ret: self.unifier.get_dummy_var().ty,
|
||||
vars: VarMap::default(),
|
||||
opts: vec![],
|
||||
});
|
||||
let exit = self.unifier.add_ty(exit);
|
||||
let mut fields = HashMap::new();
|
||||
@ -766,6 +768,7 @@ impl<'a> Inferencer<'a> {
|
||||
.collect(),
|
||||
ret,
|
||||
vars: VarMap::default(),
|
||||
opts: vec![],
|
||||
};
|
||||
let body = new_context.fold_expr(body)?;
|
||||
new_context.unify(fun.ret, body.custom.unwrap(), &location)?;
|
||||
@ -1107,6 +1110,7 @@ impl<'a> Inferencer<'a> {
|
||||
}],
|
||||
ret: ret_ty,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
@ -1166,6 +1170,7 @@ impl<'a> Inferencer<'a> {
|
||||
}],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
@ -1214,6 +1219,7 @@ impl<'a> Inferencer<'a> {
|
||||
],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
@ -1253,6 +1259,7 @@ impl<'a> Inferencer<'a> {
|
||||
}],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
@ -1365,6 +1372,7 @@ impl<'a> Inferencer<'a> {
|
||||
],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
@ -1445,6 +1453,7 @@ impl<'a> Inferencer<'a> {
|
||||
}],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
@ -1487,6 +1496,7 @@ impl<'a> Inferencer<'a> {
|
||||
}],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
@ -1532,6 +1542,7 @@ impl<'a> Inferencer<'a> {
|
||||
],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
@ -1586,6 +1597,7 @@ impl<'a> Inferencer<'a> {
|
||||
],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
@ -1654,6 +1666,7 @@ impl<'a> Inferencer<'a> {
|
||||
],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
|
@ -91,6 +91,7 @@ impl TestEnvironment {
|
||||
}],
|
||||
ret: int32,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
fields.insert("__add__".into(), (add_ty, false));
|
||||
});
|
||||
@ -237,6 +238,7 @@ impl TestEnvironment {
|
||||
}],
|
||||
ret: int32,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
fields.insert("__add__".into(), (add_ty, false));
|
||||
});
|
||||
@ -386,6 +388,7 @@ impl TestEnvironment {
|
||||
args: vec![],
|
||||
ret: foo_ty,
|
||||
vars: into_var_map([tvar]),
|
||||
opts: vec![],
|
||||
})),
|
||||
);
|
||||
|
||||
@ -393,6 +396,7 @@ impl TestEnvironment {
|
||||
args: vec![],
|
||||
ret: int32,
|
||||
vars: IndexMap::default(),
|
||||
opts: vec![],
|
||||
}));
|
||||
let bar = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(defs + 2),
|
||||
@ -420,6 +424,7 @@ impl TestEnvironment {
|
||||
args: vec![],
|
||||
ret: bar,
|
||||
vars: IndexMap::default(),
|
||||
opts: vec![],
|
||||
})),
|
||||
);
|
||||
|
||||
@ -449,6 +454,7 @@ impl TestEnvironment {
|
||||
args: vec![],
|
||||
ret: bar2,
|
||||
vars: IndexMap::default(),
|
||||
opts: vec![],
|
||||
})),
|
||||
);
|
||||
let class_names: HashMap<_, _> = [("Bar".into(), bar), ("Bar2".into(), bar2)].into();
|
||||
|
@ -123,11 +123,17 @@ impl FuncArg {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum FunOption {
|
||||
Async,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FunSignature {
|
||||
pub args: Vec<FuncArg>,
|
||||
pub ret: Type,
|
||||
pub vars: VarMap,
|
||||
pub opts: Vec<FunOption>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
@ -1565,7 +1571,7 @@ impl Unifier {
|
||||
None
|
||||
}
|
||||
}
|
||||
TypeEnum::TFunc(FunSignature { args, ret, vars: params }) => {
|
||||
TypeEnum::TFunc(FunSignature { args, ret, vars: params, opts }) => {
|
||||
let new_params = self.subst_map(params, mapping, cache);
|
||||
let new_ret = self.subst_impl(*ret, mapping, cache);
|
||||
let mut new_args = Cow::from(args);
|
||||
@ -1580,7 +1586,8 @@ impl Unifier {
|
||||
let params = new_params.unwrap_or_else(|| params.clone());
|
||||
let ret = new_ret.unwrap_or(*ret);
|
||||
let args = new_args.into_owned();
|
||||
Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, vars: params })))
|
||||
let opts = opts.clone();
|
||||
Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, vars: params, opts })))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -389,6 +389,7 @@ fn test_virtual() {
|
||||
args: vec![],
|
||||
ret: int,
|
||||
vars: VarMap::new(),
|
||||
opts: vec![],
|
||||
}));
|
||||
let bar = env.unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(5),
|
||||
|
@ -360,7 +360,7 @@ fn main() {
|
||||
}
|
||||
}
|
||||
|
||||
let signature = FunSignature { args: vec![], ret: primitive.int32, vars: VarMap::new() };
|
||||
let signature = FunSignature { args: vec![], ret: primitive.int32, vars: VarMap::new(), opts: vec![] };
|
||||
let mut store = ConcreteTypeStore::new();
|
||||
let mut cache = HashMap::new();
|
||||
let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache);
|
||||
|
Loading…
Reference in New Issue
Block a user