1
0
forked from M-Labs/nac3

Compare commits

..

10 Commits

Author SHA1 Message Date
219af79017 FunSignature: add "opts" field 2024-09-10 17:22:59 +08:00
987d17b5e8 nac3artiq: detect decorators with arguments 2024-09-10 17:22:19 +08:00
7204513a9f debug 2024-09-06 16:33:58 +08:00
9c33c4209c [core] Fix type of ndarray.element_type
Should be the element type of the NDArray itself, not the pointer to its
type.
2024-08-30 22:47:38 +08:00
122983f11c flake: update dependencies 2024-08-30 14:45:38 +08:00
71c3a65a31 [core] codegen/stmt: Fix obtaining return type of sret functions 2024-08-29 19:15:30 +08:00
8c540d1033 [core] codegen/stmt: Add more casts for boolean types 2024-08-29 16:36:32 +08:00
0cc60a3d33 [core] codegen/expr: Fix missing cast to i1 2024-08-29 16:36:32 +08:00
a59c26aa99 [artiq] Fix RPC of ndarrays from host 2024-08-29 16:08:45 +08:00
02d93b11d1 [meta] Update dependencies 2024-08-29 14:32:21 +08:00
19 changed files with 468 additions and 133 deletions

62
Cargo.lock generated
View File

@ -117,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.13" version = "1.1.15"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72db2f7947ecee9b03b510377e8bb9077afa27176fdbff55c51027e976fdcc48" checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6"
dependencies = [ dependencies = [
"shlex", "shlex",
] ]
@ -161,7 +161,7 @@ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
@ -310,9 +310,9 @@ dependencies = [
[[package]] [[package]]
name = "fastrand" name = "fastrand"
version = "2.1.0" version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6"
[[package]] [[package]]
name = "fixedbitset" name = "fixedbitset"
@ -424,7 +424,7 @@ checksum = "4fa4d8d74483041a882adaa9a29f633253a66dde85055f0495c121620ac484b2"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
@ -510,9 +510,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.157" version = "0.2.158"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374af5f94e54fa97cf75e945cce8a6b201e88a1a07e688b47dfd2a59c66dbd86" checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
[[package]] [[package]]
name = "libloading" name = "libloading"
@ -752,7 +752,7 @@ dependencies = [
"phf_shared 0.11.2", "phf_shared 0.11.2",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
@ -856,7 +856,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-macros-backend", "pyo3-macros-backend",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
@ -869,14 +869,14 @@ dependencies = [
"proc-macro2", "proc-macro2",
"pyo3-build-config", "pyo3-build-config",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.36" version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
] ]
@ -942,9 +942,9 @@ dependencies = [
[[package]] [[package]]
name = "redox_users" name = "redox_users"
version = "0.4.5" version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
dependencies = [ dependencies = [
"getrandom", "getrandom",
"libredox", "libredox",
@ -989,9 +989,9 @@ dependencies = [
[[package]] [[package]]
name = "rustix" name = "rustix"
version = "0.38.34" version = "0.38.35"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" checksum = "a85d50532239da68e9addb745ba38ff4612a242c1c7ceea689c4bc7c2f43c36f"
dependencies = [ dependencies = [
"bitflags", "bitflags",
"errno", "errno",
@ -1035,29 +1035,29 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.208" version = "1.0.209"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cff085d2cb684faa248efb494c39b68e522822ac0de72ccf08109abde717cfb2" checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09"
dependencies = [ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]] [[package]]
name = "serde_derive" name = "serde_derive"
version = "1.0.208" version = "1.0.209"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24008e81ff7613ed8e5ba0cfaf24e2c2f1e5b8a0495711e44fcd4882fca62bcf" checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
name = "serde_json" name = "serde_json"
version = "1.0.125" version = "1.0.127"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed" checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad"
dependencies = [ dependencies = [
"itoa", "itoa",
"memchr", "memchr",
@ -1147,7 +1147,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"rustversion", "rustversion",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
@ -1163,9 +1163,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.75" version = "2.0.76"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6af063034fc1935ede7be0122941bafa9bacb949334d090b77ca98b5817c7d9" checksum = "578e081a14e0cefc3279b0472138c513f37b41a08d5a3cca9b6e4e8ceb6cd525"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -1232,7 +1232,7 @@ checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]
[[package]] [[package]]
@ -1310,9 +1310,9 @@ checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d"
[[package]] [[package]]
name = "unicode-xid" name = "unicode-xid"
version = "0.2.4" version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a"
[[package]] [[package]]
name = "unicode_names2" name = "unicode_names2"
@ -1510,5 +1510,5 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.75", "syn 2.0.76",
] ]

6
flake.lock generated
View File

@ -2,11 +2,11 @@
"nodes": { "nodes": {
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1723637854, "lastModified": 1724819573,
"narHash": "sha256-med8+5DSWa2UnOqtdICndjDAEjxr5D7zaIiK4pn0Q7c=", "narHash": "sha256-GnR7/ibgIH1vhoy8cYdmXE6iyZqKqFxQSVkFgosBh6w=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "c3aa7b8938b17aebd2deecf7be0636000d62a2b9", "rev": "71e91c409d1e654808b2621f28a327acfdad8dc2",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -112,10 +112,14 @@ def extern(function):
register_function(function) register_function(function)
return function return function
def rpc(function): def rpc(function, flags={}):
"""Decorates a function declaration defined by the core device runtime.""" """Decorates a function declaration defined by the core device runtime."""
register_function(function) register_function(function)
return function @wraps(function)
def wrapped(function)
return function
wrapped.__artiq_rpc_flags__ = flags
return wrapped
def kernel(function_or_method): def kernel(function_or_method):
"""Decorates a function or method to be executed on the core device.""" """Decorates a function or method to be executed on the core device."""

View File

@ -2,7 +2,7 @@ use nac3core::{
codegen::{ codegen::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType,
NDArrayValue, RangeValue, UntypedArrayLikeAccessor, NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor,
}, },
expr::{destructure_range, gen_call}, expr::{destructure_range, gen_call},
irrt::call_ndarray_calc_size, irrt::call_ndarray_calc_size,
@ -22,7 +22,7 @@ use inkwell::{
module::Linkage, module::Linkage,
types::{BasicType, IntType}, types::{BasicType, IntType},
values::{BasicValueEnum, PointerValue, StructValue}, values::{BasicValueEnum, PointerValue, StructValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use pyo3::{ use pyo3::{
@ -32,6 +32,7 @@ use pyo3::{
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use inkwell::values::IntValue;
use itertools::Itertools; use itertools::Itertools;
use std::{ use std::{
collections::{hash_map::DefaultHasher, HashMap}, collections::{hash_map::DefaultHasher, HashMap},
@ -460,8 +461,8 @@ fn format_rpc_arg<'ctx>(
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
let llvm_arg_ty = let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
NDArrayType::new(generator, ctx.ctx, ctx.get_llvm_type(generator, elem_ty)); let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let llvm_usize_sizeof = ctx let llvm_usize_sizeof = ctx
@ -471,7 +472,7 @@ fn format_rpc_arg<'ctx>(
let llvm_pdata_sizeof = ctx let llvm_pdata_sizeof = ctx
.builder .builder
.build_int_truncate_or_bit_cast( .build_int_truncate_or_bit_cast(
llvm_arg_ty.element_type().ptr_type(AddressSpace::default()).size_of(), llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
llvm_usize, llvm_usize,
"", "",
) )
@ -486,13 +487,10 @@ fn format_rpc_arg<'ctx>(
let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
let ppdata = generator.gen_var_alloc(ctx, llvm_arg_ty.element_type(), None).unwrap();
ctx.builder.build_store(ppdata, llvm_arg.data().base_ptr(ctx, generator)).unwrap();
call_memcpy_generic( call_memcpy_generic(
ctx, ctx,
buffer.base_ptr(ctx, generator), buffer.base_ptr(ctx, generator),
ppdata, llvm_arg.ptr_to_data(ctx),
llvm_pdata_sizeof, llvm_pdata_sizeof,
llvm_i1.const_zero(), llvm_i1.const_zero(),
); );
@ -528,6 +526,298 @@ fn format_rpc_arg<'ctx>(
arg_slot arg_slot
} }
/// Formats an RPC return value to conform to the expected format required by NAC3.
fn format_rpc_ret<'ctx>(
generator: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
ret_ty: Type,
) -> Option<BasicValueEnum<'ctx>> {
// -- receive value:
// T result = {
// void *ret_ptr = alloca(sizeof(T));
// void *ptr = ret_ptr;
// loop: int size = rpc_recv(ptr);
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
// if(size) { ptr = alloca(size); goto loop; }
// else *(T*)ret_ptr
// }
let llvm_i8 = ctx.ctx.i8_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
});
if ctx.unifier.unioned(ret_ty, ctx.primitives.none) {
ctx.build_call_or_invoke(rpc_recv, &[llvm_pi8.const_null().into()], "rpc_recv");
return None;
}
let prehead_bb = ctx.builder.get_insert_block().unwrap();
let current_function = prehead_bb.get_parent().unwrap();
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
let llvm_ret_ty = ctx.get_llvm_abi_type(generator, ret_ty);
let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i1 = ctx.ctx.bool_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
// Round `val` up to its modulo `power_of_two`
let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
val: IntValue<'ctx>,
power_of_two: IntValue<'ctx>| {
debug_assert_eq!(
val.get_type().get_bit_width(),
power_of_two.get_type().get_bit_width()
);
let llvm_val_t = val.get_type();
let max_rem = ctx
.builder
.build_int_sub(power_of_two, llvm_val_t.const_int(1, false), "")
.unwrap();
ctx.builder
.build_and(
ctx.builder.build_int_add(val, max_rem, "").unwrap(),
ctx.builder.build_not(max_rem, "").unwrap(),
"",
)
.unwrap()
};
// Setup types
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty);
// Allocate the resulting ndarray
// A condition after format_rpc_ret ensures this will not be popped this off.
let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result"));
// Setup ndims
let ndims =
if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) {
assert_eq!(values.len(), 1);
u64::try_from(values[0].clone()).unwrap()
} else {
unreachable!();
};
// Set `ndarray.ndims`
ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
// Allocate `ndarray.shape` [size_t; ndims]
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx));
/*
ndarray now:
- .ndims: initialized
- .shape: allocated but uninitialized .shape
- .data: uninitialized
*/
let llvm_usize_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
.unwrap();
let llvm_pdata_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(
llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
llvm_usize,
"",
)
.unwrap();
let llvm_elem_sizeof = ctx
.builder
.build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
.unwrap();
// Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
// (4 + 4 * ndims) bytes with 8-byte alignment
let sizeof_dims =
ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
let unaligned_buffer_size =
ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false));
let stackptr = call_stacksave(ctx, None);
// Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment
let buffer = ctx
.builder
.build_array_alloca(
llvm_i8_8,
ctx.builder
.build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "")
.unwrap(),
"rpc.buffer",
)
.unwrap();
let buffer = ctx
.builder
.build_bitcast(buffer, llvm_pi8, "")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
//
// The returned value is the number of bytes for `ndarray.data`.
let ndarray_nbytes = ctx
.build_call_or_invoke(
rpc_recv,
&[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims].
"rpc.size.next",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
// debug_assert(ndarray_nbytes > 0)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::UGT,
ndarray_nbytes,
ndarray_nbytes.get_type().const_zero(),
"",
)
.unwrap(),
"0:AssertionError",
"Unexpected RPC termination for ndarray - Expected data buffer next",
[None, None, None],
ctx.current_loc,
);
}
// Copy shape from the buffer to `ndarray.shape`.
let pbuffer_dims =
unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
pbuffer_dims,
sizeof_dims,
llvm_i1.const_zero(),
);
// Restore stack from before allocation of buffer
call_stackrestore(ctx, stackptr);
// Allocate `ndarray.data`.
// `ndarray.shape` must be initialized beforehand in this implementation
// (for ndarray.create_data() to know how many elements to allocate)
let num_elements =
call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None));
// debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let sizeof_data =
ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::UGE,
sizeof_data,
ndarray_nbytes,
"",
).unwrap(),
"0:AssertionError",
"Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
[Some(sizeof_data), Some(ndarray_nbytes), None],
ctx.current_loc,
);
}
ndarray.create_data(ctx, llvm_elem_ty, num_elements);
let ndarray_data = ndarray.data().base_ptr(ctx, generator);
let ndarray_data_i8 =
ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
// NOTE: Currently on `prehead_bb`
ctx.builder.build_unconditional_branch(head_bb).unwrap();
// Inserting into `head_bb`. Do `rpc_recv` for `data` recursively.
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.map(BasicValueEnum::into_int_value)
.unwrap();
let is_done = ctx
.builder
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb);
// Align the allocation to sizeof(T)
let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof);
let alloc_ptr = ctx
.builder
.build_array_alloca(
llvm_elem_ty,
ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(),
"rpc.alloc",
)
.unwrap();
let alloc_ptr =
ctx.builder.build_pointer_cast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb);
ndarray.as_base_value().into()
}
_ => {
let slot = ctx.builder.build_alloca(llvm_ret_ty, "rpc.ret.slot").unwrap();
let slotgen = ctx.builder.build_bitcast(slot, llvm_pi8, "rpc.ret.ptr").unwrap();
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
phi.add_incoming(&[(&slotgen, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap()
.into_int_value();
let is_done = ctx
.builder
.build_int_compare(IntPredicate::EQ, llvm_i32.const_zero(), alloc_size, "rpc.done")
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb);
let alloc_ptr =
ctx.builder.build_array_alloca(llvm_pi8, alloc_size, "rpc.alloc").unwrap();
let alloc_ptr =
ctx.builder.build_bitcast(alloc_ptr, llvm_pi8, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb);
ctx.builder.build_load(slot, "rpc.result").unwrap()
}
};
Some(result)
}
fn rpc_codegen_callback_fn<'ctx>( fn rpc_codegen_callback_fn<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: Option<(Type, ValueEnum<'ctx>)>,
@ -541,6 +831,9 @@ fn rpc_codegen_callback_fn<'ctx>(
let ptr_type = int8.ptr_type(AddressSpace::default()); let ptr_type = int8.ptr_type(AddressSpace::default());
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
// println!("obj: {:?}", obj);
println!("fun: {:?}", fun);
let service_id = int32.const_int(fun.1 .0 as u64, false); let service_id = int32.const_int(fun.1 .0 as u64, false);
// -- setup rpc tags // -- setup rpc tags
let mut tag = Vec::new(); let mut tag = Vec::new();
@ -663,63 +956,14 @@ fn rpc_codegen_callback_fn<'ctx>(
// reclaim stack space used by arguments // reclaim stack space used by arguments
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
// -- receive value: let result = format_rpc_ret(generator, ctx, fun.0.ret);
// T result = {
// void *ret_ptr = alloca(sizeof(T));
// void *ptr = ret_ptr;
// loop: int size = rpc_recv(ptr);
// // Non-zero: Provide `size` bytes of extra storage for variable-length data.
// if(size) { ptr = alloca(size); goto loop; }
// else *(T*)ret_ptr
// }
let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
ctx.module.add_function("rpc_recv", int32.fn_type(&[ptr_type.into()], false), None)
});
if ctx.unifier.unioned(fun.0.ret, ctx.primitives.none) { if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
ctx.build_call_or_invoke(rpc_recv, &[ptr_type.const_null().into()], "rpc_recv"); // An RPC returning an NDArray would not touch here.
return Ok(None);
}
let prehead_bb = ctx.builder.get_insert_block().unwrap();
let current_function = prehead_bb.get_parent().unwrap();
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
let alloc_bb = ctx.ctx.append_basic_block(current_function, "rpc.continue");
let tail_bb = ctx.ctx.append_basic_block(current_function, "rpc.tail");
let ret_ty = ctx.get_llvm_abi_type(generator, fun.0.ret);
let need_load = !ret_ty.is_pointer_type();
let slot = ctx.builder.build_alloca(ret_ty, "rpc.ret.slot").unwrap();
let slotgen = ctx.builder.build_bitcast(slot, ptr_type, "rpc.ret.ptr").unwrap();
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(head_bb);
let phi = ctx.builder.build_phi(ptr_type, "rpc.ptr").unwrap();
phi.add_incoming(&[(&slotgen, prehead_bb)]);
let alloc_size = ctx
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap()
.into_int_value();
let is_done = ctx
.builder
.build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done")
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
ctx.builder.position_at_end(alloc_bb);
let alloc_ptr = ctx.builder.build_array_alloca(ptr_type, alloc_size, "rpc.alloc").unwrap();
let alloc_ptr = ctx.builder.build_bitcast(alloc_ptr, ptr_type, "rpc.alloc.ptr").unwrap();
phi.add_incoming(&[(&alloc_ptr, alloc_bb)]);
ctx.builder.build_unconditional_branch(head_bb).unwrap();
ctx.builder.position_at_end(tail_bb);
let result = ctx.builder.build_load(slot, "rpc.result").unwrap();
if need_load {
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
} }
Ok(Some(result))
Ok(result)
} }
pub fn attributes_writeback( pub fn attributes_writeback(
@ -810,6 +1054,7 @@ pub fn attributes_writeback(
.collect(), .collect(),
ret: ctx.primitives.none, ret: ctx.primitives.none,
vars: VarMap::default(), vars: VarMap::default(),
opts: vec![],
}; };
let args: Vec<_> = let args: Vec<_> =
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect(); values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();

View File

@ -191,14 +191,30 @@ impl Nac3 {
}) })
.unwrap() .unwrap()
}); });
body.retain(|stmt| { body.retain_mut(|stmt| {
if let StmtKind::FunctionDef { ref decorator_list, .. } = stmt.node { if let StmtKind::FunctionDef { ref mut decorator_list, .. } = stmt.node {
decorator_list.iter().any(|decorator| { decorator_list.iter_mut().any(|decorator| {
if let ExprKind::Name { id, .. } = decorator.node { if let ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "kernel" id.to_string() == "kernel"
|| id.to_string() == "portable" || id.to_string() == "portable"
|| id.to_string() == "rpc" || id.to_string() == "rpc"
} else if let ExprKind::Call { func, .. } = &decorator.node {
// decorators with flags (e.g. rpc async) have Call for the node;
// this is to remove the middle
if let ExprKind::Name { id, .. } = func.node {
if id.to_string() == "rpc" {
println!("found rpc: {:?}", func);
println!("decorator node: {:?}", decorator.node);
decorator.node = func.clone().node;
true
} else { } else {
false
}
} else {
false
}
}
else {
false false
} }
}) })
@ -318,6 +334,7 @@ impl Nac3 {
}], }],
ret: primitives.none, ret: primitives.none,
vars: into_var_map([arg_ty]), vars: into_var_map([arg_ty]),
opts: vec![],
}, },
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| { Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
gen_core_log(ctx, &obj, fun, &args, generator)?; gen_core_log(ctx, &obj, fun, &args, generator)?;
@ -348,6 +365,7 @@ impl Nac3 {
], ],
ret: primitives.none, ret: primitives.none,
vars: into_var_map([arg_ty]), vars: into_var_map([arg_ty]),
opts: vec![],
}, },
Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| { Arc::new(GenCall::new(Box::new(move |ctx, obj, fun, args, generator| {
gen_rtio_log(ctx, &obj, fun, &args, generator)?; gen_rtio_log(ctx, &obj, fun, &args, generator)?;
@ -495,7 +513,7 @@ impl Nac3 {
class_name, stmt.location class_name, stmt.location
))); )));
} }
rpc_ids.push((Some((class_obj.clone(), *name)), def_id)); rpc_ids.push((Some((class_obj.clone(), *name)), def_id, Some(FuncFlags::Async)));
} }
} }
} }
@ -560,7 +578,7 @@ impl Nac3 {
let irrt = load_irrt(&context, resolver.as_ref()); let irrt = load_irrt(&context, resolver.as_ref());
let fun_signature = let fun_signature =
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() }; FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new(), opts: vec![] };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new(); let mut cache = HashMap::new();
let signature = store.from_signature( let signature = store.from_signature(
@ -598,7 +616,7 @@ impl Nac3 {
{ {
let rpc_codegen = rpc_codegen_callback(); let rpc_codegen = rpc_codegen_callback();
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
for (class_data, id) in &rpc_ids { for (class_data, id, flags) in &rpc_ids {
let mut def = defs[id.0].write(); let mut def = defs[id.0].write();
match &mut *def { match &mut *def {
TopLevelDef::Function { codegen_callback, .. } => { TopLevelDef::Function { codegen_callback, .. } => {
@ -917,7 +935,7 @@ impl Nac3 {
let builtins = vec![ let builtins = vec![
( (
"now_mu".into(), "now_mu".into(),
FunSignature { args: vec![], ret: primitive.int64, vars: VarMap::new() }, FunSignature { args: vec![], ret: primitive.int64, vars: VarMap::new(), opts: vec![] },
Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| { Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| {
Ok(Some(time_fns.emit_now_mu(ctx))) Ok(Some(time_fns.emit_now_mu(ctx)))
}))), }))),
@ -933,6 +951,7 @@ impl Nac3 {
}], }],
ret: primitive.none, ret: primitive.none,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
}, },
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| { Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
@ -953,6 +972,7 @@ impl Nac3 {
}], }],
ret: primitive.none, ret: primitive.none,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
}, },
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| { Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;

View File

@ -1250,11 +1250,13 @@ impl<'ctx> NDArrayType<'ctx> {
/// Returns the element type of this `ndarray` type. /// Returns the element type of this `ndarray` type.
#[must_use] #[must_use]
pub fn element_type(&self) -> BasicTypeEnum<'ctx> { pub fn element_type(&self) -> AnyTypeEnum<'ctx> {
self.as_base_type() self.as_base_type()
.get_element_type() .get_element_type()
.into_struct_type() .into_struct_type()
.get_field_type_at_index(2) .get_field_type_at_index(2)
.map(BasicTypeEnum::into_pointer_type)
.map(PointerType::get_element_type)
.unwrap() .unwrap()
} }
} }
@ -1404,7 +1406,7 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field. /// on the field.
fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();

View File

@ -297,6 +297,7 @@ impl ConcreteTypeStore {
let ty = self.to_unifier_type(unifier, primitives, *cty, cache); let ty = self.to_unifier_type(unifier, primitives, *cty, cache);
TypeVar { id, ty } TypeVar { id, ty }
})), })),
opts: vec![],
}), }),
ConcreteTypeEnum::TLiteral { values, .. } => { ConcreteTypeEnum::TLiteral { values, .. } => {
TypeEnum::TLiteral { values: values.clone(), loc: None } TypeEnum::TLiteral { values: values.clone(), loc: None }

View File

@ -2385,7 +2385,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
}) })
.map(BasicValueEnum::into_int_value)?; .map(BasicValueEnum::into_int_value)?;
Ok(ctx.builder.build_not(cmp, "").unwrap()) Ok(ctx.builder.build_not(
generator.bool_to_i1(ctx, cmp),
"",
).unwrap())
}, },
|_, ctx| { |_, ctx| {
let bb = ctx.builder.get_insert_block().unwrap(); let bb = ctx.builder.get_insert_block().unwrap();

View File

@ -174,6 +174,14 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
} }
} }
let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?; let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?;
// Perform i1 <-> i8 conversion as needed
let val = if ctx.unifier.unioned(target.custom.unwrap(), ctx.primitives.bool) {
generator.bool_to_i8(ctx, val.into_int_value()).into()
} else {
val
};
ctx.builder.build_store(ptr, val).unwrap(); ctx.builder.build_store(ptr, val).unwrap();
} }
}; };
@ -1665,6 +1673,23 @@ pub fn gen_return<G: CodeGenerator>(
} else { } else {
None None
}; };
// Remap boolean return type into i1
let value = value.map(|ret_val| {
// The "return type" of a sret function is in the first parameter
let expected_ty = if ctx.need_sret {
func.get_type().get_param_types()[0]
} else {
func.get_type().get_return_type().unwrap()
};
if matches!(expected_ty, BasicTypeEnum::IntType(ty) if ty.get_bit_width() == 1) {
generator.bool_to_i1(ctx, ret_val.into_int_value()).into()
} else {
ret_val
}
});
if let Some(return_target) = ctx.return_target { if let Some(return_target) = ctx.return_target {
if let Some(value) = value { if let Some(value) = value {
ctx.builder.build_store(ctx.return_buffer.unwrap(), value).unwrap(); ctx.builder.build_store(ctx.return_buffer.unwrap(), value).unwrap();
@ -1675,25 +1700,6 @@ pub fn gen_return<G: CodeGenerator>(
ctx.builder.build_store(ctx.return_buffer.unwrap(), value.unwrap()).unwrap(); ctx.builder.build_store(ctx.return_buffer.unwrap(), value.unwrap()).unwrap();
ctx.builder.build_return(None).unwrap(); ctx.builder.build_return(None).unwrap();
} else { } else {
// Remap boolean return type into i1
let value = value.map(|v| {
let expected_ty = func.get_type().get_return_type().unwrap();
let ret_val = v.as_basic_value_enum();
if expected_ty.is_int_type() && ret_val.is_int_value() {
let ret_type = expected_ty.into_int_type();
let ret_val = ret_val.into_int_value();
if ret_type.get_bit_width() == 1 && ret_val.get_type().get_bit_width() != 1 {
generator.bool_to_i1(ctx, ret_val)
} else {
ret_val
}
.into()
} else {
ret_val
}
});
let value = value.as_ref().map(|v| v as &dyn BasicValue); let value = value.as_ref().map(|v| v as &dyn BasicValue);
ctx.builder.build_return(value).unwrap(); ctx.builder.build_return(value).unwrap();
} }

View File

@ -124,6 +124,7 @@ fn test_primitives() {
], ],
ret: primitives.int32, ret: primitives.int32,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
}; };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
@ -273,6 +274,7 @@ fn test_simple_call() {
}], }],
ret: primitives.int32, ret: primitives.int32,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
}; };
let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone())); let fun_ty = unifier.add_ty(TypeEnum::TFunc(signature.clone()));
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();

View File

@ -77,6 +77,7 @@ pub fn get_exn_constructor(
args: exn_cons_args, args: exn_cons_args,
ret: exn_type, ret: exn_type,
vars: VarMap::default(), vars: VarMap::default(),
opts: vec![],
})); }));
let fun_def = TopLevelDef::Function { let fun_def = TopLevelDef::Function {
name: format!("{name}.__init__"), name: format!("{name}.__init__"),
@ -137,6 +138,7 @@ fn create_fn_by_codegen(
.collect(), .collect(),
ret: ret_ty, ret: ret_ty,
vars: var_map.clone(), vars: var_map.clone(),
opts: vec![],
})), })),
var_id: Vec::default(), var_id: Vec::default(),
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
@ -670,6 +672,7 @@ impl<'a> BuiltinBuilder<'a> {
], ],
ret: range, ret: range,
vars: VarMap::default(), vars: VarMap::default(),
opts: vec![],
})) }))
}; };
@ -925,6 +928,7 @@ impl<'a> BuiltinBuilder<'a> {
}], }],
ret: self.primitives.option, ret: self.primitives.option,
vars: into_var_map([self.option_tvar]), vars: into_var_map([self.option_tvar]),
opts: vec![],
})), })),
var_id: vec![self.option_tvar.id], var_id: vec![self.option_tvar.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
@ -1060,6 +1064,7 @@ impl<'a> BuiltinBuilder<'a> {
}], }],
ret: self.num_or_ndarray_ty.ty, ret: self.num_or_ndarray_ty.ty,
vars: self.num_or_ndarray_var_map.clone(), vars: self.num_or_ndarray_var_map.clone(),
opts: vec![],
})), })),
var_id: Vec::default(), var_id: Vec::default(),
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
@ -1297,6 +1302,7 @@ impl<'a> BuiltinBuilder<'a> {
], ],
ret: ndarray, ret: ndarray,
vars: into_var_map([tv]), vars: into_var_map([tv]),
opts: vec![],
})), })),
var_id: vec![tv.id], var_id: vec![tv.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
@ -1356,6 +1362,7 @@ impl<'a> BuiltinBuilder<'a> {
], ],
ret: self.ndarray_float_2d, ret: self.ndarray_float_2d,
vars: VarMap::default(), vars: VarMap::default(),
opts: vec![],
})), })),
var_id: Vec::default(), var_id: Vec::default(),
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
@ -1403,6 +1410,7 @@ impl<'a> BuiltinBuilder<'a> {
}], }],
ret: str, ret: str,
vars: VarMap::default(), vars: VarMap::default(),
opts: vec![],
})), })),
var_id: Vec::default(), var_id: Vec::default(),
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
@ -1479,6 +1487,7 @@ impl<'a> BuiltinBuilder<'a> {
}], }],
ret: self.primitives.int32, ret: self.primitives.int32,
vars: into_var_map([arg_tvar]), vars: into_var_map([arg_tvar]),
opts: vec![],
})), })),
var_id: Vec::default(), var_id: Vec::default(),
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
@ -1520,6 +1529,7 @@ impl<'a> BuiltinBuilder<'a> {
], ],
ret: self.num_ty.ty, ret: self.num_ty.ty,
vars: self.num_var_map.clone(), vars: self.num_var_map.clone(),
opts: vec![],
})), })),
var_id: Vec::default(), var_id: Vec::default(),
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
@ -1607,6 +1617,7 @@ impl<'a> BuiltinBuilder<'a> {
.collect(), .collect(),
ret: ret_ty.ty, ret: ret_ty.ty,
vars: into_var_map([x1_ty, x2_ty, ret_ty]), vars: into_var_map([x1_ty, x2_ty, ret_ty]),
opts: vec![],
})), })),
var_id: vec![x1_ty.id, x2_ty.id], var_id: vec![x1_ty.id, x2_ty.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
@ -1648,6 +1659,7 @@ impl<'a> BuiltinBuilder<'a> {
}], }],
ret: self.num_or_ndarray_ty.ty, ret: self.num_or_ndarray_ty.ty,
vars: self.num_or_ndarray_var_map.clone(), vars: self.num_or_ndarray_var_map.clone(),
opts: vec![],
})), })),
var_id: Vec::default(), var_id: Vec::default(),
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
@ -1842,6 +1854,7 @@ impl<'a> BuiltinBuilder<'a> {
.collect(), .collect(),
ret: ret_ty.ty, ret: ret_ty.ty,
vars: into_var_map([x1_ty, x2_ty, ret_ty]), vars: into_var_map([x1_ty, x2_ty, ret_ty]),
opts: vec![],
})), })),
var_id: vec![ret_ty.id], var_id: vec![ret_ty.id],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),

View File

@ -1122,6 +1122,7 @@ impl TopLevelComposer {
args: arg_types, args: arg_types,
ret: return_ty, ret: return_ty,
vars: function_var_map, vars: function_var_map,
opts: vec![],
})); }));
unifier.unify(*dummy_ty, function_ty).map_err(|e| { unifier.unify(*dummy_ty, function_ty).map_err(|e| {
HashSet::from([e.at(Some(function_ast.location)).to_display(unifier).to_string()]) HashSet::from([e.at(Some(function_ast.location)).to_display(unifier).to_string()])
@ -1386,6 +1387,7 @@ impl TopLevelComposer {
args: arg_types, args: arg_types,
ret: ret_type, ret: ret_type,
vars: method_var_map, vars: method_var_map,
opts: vec![],
})); }));
// unify now since function type is not in type annotation define // unify now since function type is not in type annotation define
@ -1673,7 +1675,7 @@ impl TopLevelComposer {
// they may be changed with our use of placeholders // they may be changed with our use of placeholders
for (def, _) in definition_ast_list.iter().skip(self.builtin_num) { for (def, _) in definition_ast_list.iter().skip(self.builtin_num) {
if let TopLevelDef::Function { signature, var_id, .. } = &mut *def.write() { if let TopLevelDef::Function { signature, var_id, .. } = &mut *def.write() {
if let TypeEnum::TFunc(FunSignature { args, ret, vars }) = if let TypeEnum::TFunc(FunSignature { args, ret, vars, opts }) =
unifier.get_ty(*signature).as_ref() unifier.get_ty(*signature).as_ref()
{ {
let new_var_ids = vars let new_var_ids = vars
@ -1692,6 +1694,7 @@ impl TopLevelComposer {
.zip(vars.values()) .zip(vars.values())
.map(|(id, v)| (*id, *v)) .map(|(id, v)| (*id, *v))
.collect(), .collect(),
opts: opts.clone(),
}; };
unifier unifier
.unification_table .unification_table
@ -1760,6 +1763,7 @@ impl TopLevelComposer {
], ],
ret: self_type, ret: self_type,
vars: VarMap::default(), vars: VarMap::default(),
opts: vec![],
})); }));
let cons_fun = TopLevelDef::Function { let cons_fun = TopLevelDef::Function {
name: format!("{}.{}", class_name, "__init__"), name: format!("{}.{}", class_name, "__init__"),
@ -1807,6 +1811,7 @@ impl TopLevelComposer {
args: contor_args, args: contor_args,
ret: self_type, ret: self_type,
vars: contor_type_vars, vars: contor_type_vars,
opts: vec![],
})); }));
unifier.unify(constructor.unwrap(), contor_type).map_err(|e| { unifier.unify(constructor.unwrap(), contor_type).map_err(|e| {
HashSet::from([e HashSet::from([e
@ -1894,7 +1899,7 @@ impl TopLevelComposer {
} = &mut *function_def } = &mut *function_def
{ {
let signature_ty_enum = unifier.get_ty(*signature); let signature_ty_enum = unifier.get_ty(*signature);
let TypeEnum::TFunc(FunSignature { args, ret, vars }) = signature_ty_enum.as_ref() let TypeEnum::TFunc(FunSignature { args, ret, vars, .. }) = signature_ty_enum.as_ref()
else { else {
unreachable!("must be typeenum::tfunc") unreachable!("must be typeenum::tfunc")
}; };

View File

@ -461,11 +461,13 @@ impl TopLevelComposer {
args: vec![], args: vec![],
ret: bool, ret: bool,
vars: into_var_map([option_type_var]), vars: into_var_map([option_type_var]),
opts: vec![],
})); }));
let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: option_type_var.ty, ret: option_type_var.ty,
vars: into_var_map([option_type_var]), vars: into_var_map([option_type_var]),
opts: vec![],
})); }));
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(), obj_id: PrimDef::Option.id(),
@ -500,6 +502,7 @@ impl TopLevelComposer {
args: vec![], args: vec![],
ret: ndarray_copy_fun_ret_ty.ty, ret: ndarray_copy_fun_ret_ty.ty,
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
opts: vec![],
})); }));
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg { args: vec![FuncArg {
@ -510,6 +513,7 @@ impl TopLevelComposer {
}], }],
ret: none, ret: none,
vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), vars: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
opts: vec![],
})); }));
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(), obj_id: PrimDef::NDArray.id(),

View File

@ -199,6 +199,7 @@ pub fn impl_binop(
name: "other".into(), name: "other".into(),
is_vararg: false, is_vararg: false,
}], }],
opts: vec![],
})), })),
false, false,
) )
@ -219,6 +220,7 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops:
ret: ret_ty, ret: ret_ty,
vars: VarMap::new(), vars: VarMap::new(),
args: vec![], args: vec![],
opts: vec![],
})), })),
false, false,
), ),
@ -264,6 +266,7 @@ pub fn impl_cmpop(
name: "other".into(), name: "other".into(),
is_vararg: false, is_vararg: false,
}], }],
opts: vec![],
})), })),
false, false,
), ),

View File

@ -464,12 +464,14 @@ impl<'a> Fold<()> for Inferencer<'a> {
|var| var.custom.unwrap(), |var| var.custom.unwrap(),
), ),
vars: VarMap::default(), vars: VarMap::default(),
opts: vec![],
}); });
let enter = self.unifier.add_ty(enter); let enter = self.unifier.add_ty(enter);
let exit = TypeEnum::TFunc(FunSignature { let exit = TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
ret: self.unifier.get_dummy_var().ty, ret: self.unifier.get_dummy_var().ty,
vars: VarMap::default(), vars: VarMap::default(),
opts: vec![],
}); });
let exit = self.unifier.add_ty(exit); let exit = self.unifier.add_ty(exit);
let mut fields = HashMap::new(); let mut fields = HashMap::new();
@ -766,6 +768,7 @@ impl<'a> Inferencer<'a> {
.collect(), .collect(),
ret, ret,
vars: VarMap::default(), vars: VarMap::default(),
opts: vec![],
}; };
let body = new_context.fold_expr(body)?; let body = new_context.fold_expr(body)?;
new_context.unify(fun.ret, body.custom.unwrap(), &location)?; new_context.unify(fun.ret, body.custom.unwrap(), &location)?;
@ -1107,6 +1110,7 @@ impl<'a> Inferencer<'a> {
}], }],
ret: ret_ty, ret: ret_ty,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
return Ok(Some(Located { return Ok(Some(Located {
@ -1166,6 +1170,7 @@ impl<'a> Inferencer<'a> {
}], }],
ret, ret,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
return Ok(Some(Located { return Ok(Some(Located {
@ -1214,6 +1219,7 @@ impl<'a> Inferencer<'a> {
], ],
ret, ret,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
return Ok(Some(Located { return Ok(Some(Located {
@ -1253,6 +1259,7 @@ impl<'a> Inferencer<'a> {
}], }],
ret, ret,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
return Ok(Some(Located { return Ok(Some(Located {
@ -1365,6 +1372,7 @@ impl<'a> Inferencer<'a> {
], ],
ret, ret,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
return Ok(Some(Located { return Ok(Some(Located {
@ -1445,6 +1453,7 @@ impl<'a> Inferencer<'a> {
}], }],
ret, ret,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
return Ok(Some(Located { return Ok(Some(Located {
@ -1487,6 +1496,7 @@ impl<'a> Inferencer<'a> {
}], }],
ret, ret,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
return Ok(Some(Located { return Ok(Some(Located {
@ -1532,6 +1542,7 @@ impl<'a> Inferencer<'a> {
], ],
ret, ret,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
return Ok(Some(Located { return Ok(Some(Located {
@ -1586,6 +1597,7 @@ impl<'a> Inferencer<'a> {
], ],
ret, ret,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
return Ok(Some(Located { return Ok(Some(Located {
@ -1654,6 +1666,7 @@ impl<'a> Inferencer<'a> {
], ],
ret, ret,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
return Ok(Some(Located { return Ok(Some(Located {

View File

@ -91,6 +91,7 @@ impl TestEnvironment {
}], }],
ret: int32, ret: int32,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
fields.insert("__add__".into(), (add_ty, false)); fields.insert("__add__".into(), (add_ty, false));
}); });
@ -237,6 +238,7 @@ impl TestEnvironment {
}], }],
ret: int32, ret: int32,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
fields.insert("__add__".into(), (add_ty, false)); fields.insert("__add__".into(), (add_ty, false));
}); });
@ -386,6 +388,7 @@ impl TestEnvironment {
args: vec![], args: vec![],
ret: foo_ty, ret: foo_ty,
vars: into_var_map([tvar]), vars: into_var_map([tvar]),
opts: vec![],
})), })),
); );
@ -393,6 +396,7 @@ impl TestEnvironment {
args: vec![], args: vec![],
ret: int32, ret: int32,
vars: IndexMap::default(), vars: IndexMap::default(),
opts: vec![],
})); }));
let bar = unifier.add_ty(TypeEnum::TObj { let bar = unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(defs + 2), obj_id: DefinitionId(defs + 2),
@ -420,6 +424,7 @@ impl TestEnvironment {
args: vec![], args: vec![],
ret: bar, ret: bar,
vars: IndexMap::default(), vars: IndexMap::default(),
opts: vec![],
})), })),
); );
@ -449,6 +454,7 @@ impl TestEnvironment {
args: vec![], args: vec![],
ret: bar2, ret: bar2,
vars: IndexMap::default(), vars: IndexMap::default(),
opts: vec![],
})), })),
); );
let class_names: HashMap<_, _> = [("Bar".into(), bar), ("Bar2".into(), bar2)].into(); let class_names: HashMap<_, _> = [("Bar".into(), bar), ("Bar2".into(), bar2)].into();

View File

@ -123,11 +123,17 @@ impl FuncArg {
} }
} }
#[derive(Debug, Clone)]
pub enum FunOption {
Async,
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct FunSignature { pub struct FunSignature {
pub args: Vec<FuncArg>, pub args: Vec<FuncArg>,
pub ret: Type, pub ret: Type,
pub vars: VarMap, pub vars: VarMap,
pub opts: Vec<FunOption>,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@ -1565,7 +1571,7 @@ impl Unifier {
None None
} }
} }
TypeEnum::TFunc(FunSignature { args, ret, vars: params }) => { TypeEnum::TFunc(FunSignature { args, ret, vars: params, opts }) => {
let new_params = self.subst_map(params, mapping, cache); let new_params = self.subst_map(params, mapping, cache);
let new_ret = self.subst_impl(*ret, mapping, cache); let new_ret = self.subst_impl(*ret, mapping, cache);
let mut new_args = Cow::from(args); let mut new_args = Cow::from(args);
@ -1580,7 +1586,8 @@ impl Unifier {
let params = new_params.unwrap_or_else(|| params.clone()); let params = new_params.unwrap_or_else(|| params.clone());
let ret = new_ret.unwrap_or(*ret); let ret = new_ret.unwrap_or(*ret);
let args = new_args.into_owned(); let args = new_args.into_owned();
Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, vars: params }))) let opts = opts.clone();
Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, vars: params, opts })))
} else { } else {
None None
} }

View File

@ -389,6 +389,7 @@ fn test_virtual() {
args: vec![], args: vec![],
ret: int, ret: int,
vars: VarMap::new(), vars: VarMap::new(),
opts: vec![],
})); }));
let bar = env.unifier.add_ty(TypeEnum::TObj { let bar = env.unifier.add_ty(TypeEnum::TObj {
obj_id: DefinitionId(5), obj_id: DefinitionId(5),

View File

@ -360,7 +360,7 @@ fn main() {
} }
} }
let signature = FunSignature { args: vec![], ret: primitive.int32, vars: VarMap::new() }; let signature = FunSignature { args: vec![], ret: primitive.int32, vars: VarMap::new(), opts: vec![] };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new(); let mut cache = HashMap::new();
let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache); let signature = store.from_signature(&mut composer.unifier, &primitive, &signature, &mut cache);