[WIP] Update based on feedback
This commit is contained in:
parent
c083245c28
commit
c5aa042287
@ -19,23 +19,19 @@ class RpcKwargTest:
|
||||
def run(self):
|
||||
#1) All positional => a=1, b=2, c=3 -> total=6
|
||||
s1 = sum_3(1, 2, 3)
|
||||
if s1 != 6:
|
||||
raise ValueError("sum_3(1,2,3) gave the wrong result.")
|
||||
assert s1 == 6
|
||||
|
||||
#2) Use the default b=10, c=20 => a=5 => total=35
|
||||
s2 = sum_3(5)
|
||||
if s2 != 35:
|
||||
raise ValueError("sum_3(5) gave the wrong result.")
|
||||
assert s2 == 35
|
||||
|
||||
#3) a=1 (positional), b=100 (keyword), omit c => c=20 => total=121
|
||||
s3 = sum_3(1, b=100)
|
||||
if s3 != 121:
|
||||
raise ValueError("sum_3(1, b=100) gave the wrong result.")
|
||||
assert s3 == 121
|
||||
|
||||
#4) a=2, c=300 => b=10 (default) => total=312
|
||||
s4 = sum_3(a=2, c=300)
|
||||
if s4 != 312:
|
||||
raise ValueError("sum_3(a=2, c=300) gave the wrong result.")
|
||||
assert s4 == 312
|
||||
|
||||
if __name__ == "__main__":
|
||||
RpcKwargTest().run()
|
@ -1,5 +1,5 @@
|
||||
use std::{
|
||||
collections::{hash_map::DefaultHasher, HashMap},
|
||||
collections::hash_map::DefaultHasher,
|
||||
hash::{Hash, Hasher},
|
||||
iter::once,
|
||||
mem,
|
||||
@ -79,7 +79,8 @@ pub struct ArtiqCodeGenerator<'a> {
|
||||
|
||||
/// The [`ParallelMode`] of the current parallel context.
|
||||
///
|
||||
/// The current parallel context refers to the nearest `with` statement, which is used to determine when and how the timeline should be updated.
|
||||
/// The current parallel context refers to the nearest `with` statement,
|
||||
/// which is used to determine when and how the timeline should be updated.
|
||||
parallel_mode: ParallelMode,
|
||||
}
|
||||
|
||||
@ -372,12 +373,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
|
||||
fn gen_rpc_tag(
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
ty: Type,
|
||||
is_kwarg: bool, // Add this parameter
|
||||
is_kwarg: bool,
|
||||
buffer: &mut Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
// Add kwarg marker if needed
|
||||
if is_kwarg {
|
||||
buffer.push(b'k'); // 'k' for keyword argument
|
||||
buffer.push(b'k');
|
||||
}
|
||||
|
||||
use nac3core::typecheck::typedef::TypeEnum::*;
|
||||
@ -408,14 +408,14 @@ fn gen_rpc_tag(
|
||||
buffer.push(b't');
|
||||
buffer.push(ty.len() as u8);
|
||||
for ty in ty {
|
||||
gen_rpc_tag(ctx, *ty, false, buffer)?; // Pass false for is_kwarg
|
||||
gen_rpc_tag(ctx, *ty, false, buffer)?;
|
||||
}
|
||||
}
|
||||
TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
||||
let ty = iter_type_vars(params).next().unwrap().ty;
|
||||
|
||||
buffer.push(b'l');
|
||||
gen_rpc_tag(ctx, ty, false, buffer)?; // Pass false for is_kwarg
|
||||
gen_rpc_tag(ctx, ty, false, buffer)?;
|
||||
}
|
||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
@ -796,7 +796,7 @@ pub fn rpc_codegen_callback_fn<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
mut args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
is_async: bool,
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||
@ -820,8 +820,6 @@ pub fn rpc_codegen_callback_fn<'ctx>(
|
||||
tag.push(b':');
|
||||
gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?;
|
||||
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = DefaultHasher::new();
|
||||
tag.hash(&mut hasher);
|
||||
let hash = format!("rpc_tag_{}", hasher.finish());
|
||||
@ -832,12 +830,9 @@ pub fn rpc_codegen_callback_fn<'ctx>(
|
||||
} else {
|
||||
let tag_len = tag.len();
|
||||
let arr_ty = int8.array_type(tag_len as u32);
|
||||
let tag_const = int8.const_array(
|
||||
&tag.iter().map(|&b| int8.const_int(b as u64, false)).collect::<Vec<_>>(),
|
||||
);
|
||||
let arr_gv = ctx
|
||||
.module
|
||||
.add_global(arr_ty, None, &format!("{}.arr", hash));
|
||||
let tag_const = int8
|
||||
.const_array(&tag.iter().map(|&b| int8.const_int(b as u64, false)).collect::<Vec<_>>());
|
||||
let arr_gv = ctx.module.add_global(arr_ty, None, &format!("{}.arr", hash));
|
||||
arr_gv.set_linkage(Linkage::Private);
|
||||
arr_gv.set_initializer(&tag_const);
|
||||
|
||||
@ -860,13 +855,14 @@ pub fn rpc_codegen_callback_fn<'ctx>(
|
||||
|
||||
let mut pos_index = 0usize;
|
||||
|
||||
if let Some((obj_ty, obj_val)) = obj {
|
||||
if let Some((_obj_ty, obj_val)) = obj {
|
||||
param_map[0] = Some(obj_val);
|
||||
pos_index = 1;
|
||||
}
|
||||
for (maybe_key, val_enum) in args.drain(..) {
|
||||
for (maybe_key, val_enum) in args {
|
||||
if let Some(kw_name) = maybe_key {
|
||||
let param_pos = fun.0
|
||||
let param_pos = fun
|
||||
.0
|
||||
.args
|
||||
.iter()
|
||||
.position(|arg| arg.name == kw_name)
|
||||
@ -911,23 +907,18 @@ pub fn rpc_codegen_callback_fn<'ctx>(
|
||||
let i32_ty = ctx.ctx.i32_type();
|
||||
let arg_array = ctx
|
||||
.builder
|
||||
.build_array_alloca(
|
||||
ptr_type,
|
||||
i32_ty.const_int(arg_count, false),
|
||||
"rpc.arg_array",
|
||||
)
|
||||
.build_array_alloca(ptr_type, i32_ty.const_int(arg_count, false), "rpc.arg_array")
|
||||
.unwrap();
|
||||
|
||||
for (i, (llvm_val, ty)) in real_params.iter().enumerate() {
|
||||
let arg_slot_ptr = unsafe {
|
||||
ctx.builder.build_gep(
|
||||
arg_array,
|
||||
&[
|
||||
i32_ty.const_int(i as u64, false),
|
||||
],
|
||||
&[i32_ty.const_int(i as u64, false)],
|
||||
&format!("rpc.arg_slot_{}", i),
|
||||
)
|
||||
}.unwrap();
|
||||
}
|
||||
.unwrap();
|
||||
let arg_ptr = format_rpc_arg(generator, ctx, (*llvm_val, *ty, i));
|
||||
ctx.builder.build_store(arg_slot_ptr, arg_ptr).unwrap();
|
||||
}
|
||||
@ -937,7 +928,11 @@ pub fn rpc_codegen_callback_fn<'ctx>(
|
||||
ctx.module.add_function(
|
||||
"rpc_send_async",
|
||||
ctx.ctx.void_type().fn_type(
|
||||
&[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()],
|
||||
&[
|
||||
int32.into(),
|
||||
tag_ptr_type.into(),
|
||||
ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||
],
|
||||
false,
|
||||
),
|
||||
None,
|
||||
@ -946,11 +941,7 @@ pub fn rpc_codegen_callback_fn<'ctx>(
|
||||
ctx.builder
|
||||
.build_call(
|
||||
rpc_send_async,
|
||||
&[
|
||||
service_id.into(),
|
||||
tag_ptr.into(),
|
||||
arg_array.into(),
|
||||
],
|
||||
&[service_id.into(), tag_ptr.into(), arg_array.into()],
|
||||
"rpc.send_async",
|
||||
)
|
||||
.unwrap();
|
||||
@ -961,7 +952,11 @@ pub fn rpc_codegen_callback_fn<'ctx>(
|
||||
ctx.module.add_function(
|
||||
"rpc_send",
|
||||
ctx.ctx.void_type().fn_type(
|
||||
&[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()],
|
||||
&[
|
||||
int32.into(),
|
||||
tag_ptr_type.into(),
|
||||
ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||
],
|
||||
false,
|
||||
),
|
||||
None,
|
||||
@ -970,11 +965,7 @@ pub fn rpc_codegen_callback_fn<'ctx>(
|
||||
ctx.builder
|
||||
.build_call(
|
||||
rpc_send,
|
||||
&[
|
||||
service_id.into(),
|
||||
tag_ptr.into(),
|
||||
arg_array.into(),
|
||||
],
|
||||
&[service_id.into(), tag_ptr.into(), arg_array.into()],
|
||||
"rpc.send",
|
||||
)
|
||||
.unwrap();
|
||||
|
Loading…
Reference in New Issue
Block a user