[WIP] Update based on feedback

This commit is contained in:
ram 2025-02-09 16:42:52 +00:00
parent c083245c28
commit c5aa042287
2 changed files with 35 additions and 48 deletions

View File

@ -19,23 +19,19 @@ class RpcKwargTest:
def run(self):
#1) All positional => a=1, b=2, c=3 -> total=6
s1 = sum_3(1, 2, 3)
if s1 != 6:
raise ValueError("sum_3(1,2,3) gave the wrong result.")
assert s1 == 6
#2) Use the default b=10, c=20 => a=5 => total=35
s2 = sum_3(5)
if s2 != 35:
raise ValueError("sum_3(5) gave the wrong result.")
assert s2 == 35
#3) a=1 (positional), b=100 (keyword), omit c => c=20 => total=121
s3 = sum_3(1, b=100)
if s3 != 121:
raise ValueError("sum_3(1, b=100) gave the wrong result.")
assert s3 == 121
#4) a=2, c=300 => b=10 (default) => total=312
s4 = sum_3(a=2, c=300)
if s4 != 312:
raise ValueError("sum_3(a=2, c=300) gave the wrong result.")
assert s4 == 312
if __name__ == "__main__":
RpcKwargTest().run()

View File

@ -1,5 +1,5 @@
use std::{
collections::{hash_map::DefaultHasher, HashMap},
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
iter::once,
mem,
@ -79,7 +79,8 @@ pub struct ArtiqCodeGenerator<'a> {
/// The [`ParallelMode`] of the current parallel context.
///
/// The current parallel context refers to the nearest `with` statement, which is used to determine when and how the timeline should be updated.
/// The current parallel context refers to the nearest `with` statement,
/// which is used to determine when and how the timeline should be updated.
parallel_mode: ParallelMode,
}
@ -372,12 +373,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
fn gen_rpc_tag(
ctx: &mut CodeGenContext<'_, '_>,
ty: Type,
is_kwarg: bool, // Add this parameter
is_kwarg: bool,
buffer: &mut Vec<u8>,
) -> Result<(), String> {
// Add kwarg marker if needed
if is_kwarg {
buffer.push(b'k'); // 'k' for keyword argument
buffer.push(b'k');
}
use nac3core::typecheck::typedef::TypeEnum::*;
@ -408,14 +408,14 @@ fn gen_rpc_tag(
buffer.push(b't');
buffer.push(ty.len() as u8);
for ty in ty {
gen_rpc_tag(ctx, *ty, false, buffer)?; // Pass false for is_kwarg
gen_rpc_tag(ctx, *ty, false, buffer)?;
}
}
TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
let ty = iter_type_vars(params).next().unwrap().ty;
buffer.push(b'l');
gen_rpc_tag(ctx, ty, false, buffer)?; // Pass false for is_kwarg
gen_rpc_tag(ctx, ty, false, buffer)?;
}
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
@ -796,7 +796,7 @@ pub fn rpc_codegen_callback_fn<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
mut args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
is_async: bool,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
@ -820,8 +820,6 @@ pub fn rpc_codegen_callback_fn<'ctx>(
tag.push(b':');
gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
tag.hash(&mut hasher);
let hash = format!("rpc_tag_{}", hasher.finish());
@ -832,12 +830,9 @@ pub fn rpc_codegen_callback_fn<'ctx>(
} else {
let tag_len = tag.len();
let arr_ty = int8.array_type(tag_len as u32);
let tag_const = int8.const_array(
&tag.iter().map(|&b| int8.const_int(b as u64, false)).collect::<Vec<_>>(),
);
let arr_gv = ctx
.module
.add_global(arr_ty, None, &format!("{}.arr", hash));
let tag_const = int8
.const_array(&tag.iter().map(|&b| int8.const_int(b as u64, false)).collect::<Vec<_>>());
let arr_gv = ctx.module.add_global(arr_ty, None, &format!("{}.arr", hash));
arr_gv.set_linkage(Linkage::Private);
arr_gv.set_initializer(&tag_const);
@ -860,13 +855,14 @@ pub fn rpc_codegen_callback_fn<'ctx>(
let mut pos_index = 0usize;
if let Some((obj_ty, obj_val)) = obj {
if let Some((_obj_ty, obj_val)) = obj {
param_map[0] = Some(obj_val);
pos_index = 1;
}
for (maybe_key, val_enum) in args.drain(..) {
for (maybe_key, val_enum) in args {
if let Some(kw_name) = maybe_key {
let param_pos = fun.0
let param_pos = fun
.0
.args
.iter()
.position(|arg| arg.name == kw_name)
@ -911,23 +907,18 @@ pub fn rpc_codegen_callback_fn<'ctx>(
let i32_ty = ctx.ctx.i32_type();
let arg_array = ctx
.builder
.build_array_alloca(
ptr_type,
i32_ty.const_int(arg_count, false),
"rpc.arg_array",
)
.build_array_alloca(ptr_type, i32_ty.const_int(arg_count, false), "rpc.arg_array")
.unwrap();
for (i, (llvm_val, ty)) in real_params.iter().enumerate() {
let arg_slot_ptr = unsafe {
ctx.builder.build_gep(
arg_array,
&[
i32_ty.const_int(i as u64, false),
],
&[i32_ty.const_int(i as u64, false)],
&format!("rpc.arg_slot_{}", i),
)
}.unwrap();
}
.unwrap();
let arg_ptr = format_rpc_arg(generator, ctx, (*llvm_val, *ty, i));
ctx.builder.build_store(arg_slot_ptr, arg_ptr).unwrap();
}
@ -937,7 +928,11 @@ pub fn rpc_codegen_callback_fn<'ctx>(
ctx.module.add_function(
"rpc_send_async",
ctx.ctx.void_type().fn_type(
&[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()],
&[
int32.into(),
tag_ptr_type.into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
false,
),
None,
@ -946,11 +941,7 @@ pub fn rpc_codegen_callback_fn<'ctx>(
ctx.builder
.build_call(
rpc_send_async,
&[
service_id.into(),
tag_ptr.into(),
arg_array.into(),
],
&[service_id.into(), tag_ptr.into(), arg_array.into()],
"rpc.send_async",
)
.unwrap();
@ -961,7 +952,11 @@ pub fn rpc_codegen_callback_fn<'ctx>(
ctx.module.add_function(
"rpc_send",
ctx.ctx.void_type().fn_type(
&[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()],
&[
int32.into(),
tag_ptr_type.into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
false,
),
None,
@ -970,11 +965,7 @@ pub fn rpc_codegen_callback_fn<'ctx>(
ctx.builder
.build_call(
rpc_send,
&[
service_id.into(),
tag_ptr.into(),
arg_array.into(),
],
&[service_id.into(), tag_ptr.into(), arg_array.into()],
"rpc.send",
)
.unwrap();