forked from M-Labs/nac3
Compare commits
4 Commits
master
...
fix/unitte
Author | SHA1 | Date | |
---|---|---|---|
|
6b02ec2a07 | ||
|
45ac109c03 | ||
|
b03a5646ee | ||
637e7db70f |
@ -230,10 +230,21 @@ impl CodeGenerator for ArtiqCodeGenerator<'_> {
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
kwargs: Option<HashMap<StrRef, ValueEnum<'ctx>>>, // New parameter for keyword arguments
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||
let result = gen_call(self, ctx, obj, fun, params)?;
|
||||
let mut combined_params = params;
|
||||
|
||||
// If keyword arguments are provided, map them to the function signature
|
||||
if let Some(kwargs) = kwargs {
|
||||
for arg in &fun.0.args {
|
||||
if let Some(value) = kwargs.get(&arg.name) {
|
||||
combined_params.push((Some(arg.name), value.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let result = gen_call(self, ctx, obj, fun, combined_params)?;
|
||||
|
||||
// Deep parallel emits timeline end-update/timeline-reset after each function call
|
||||
if self.parallel_mode == ParallelMode::Deep {
|
||||
self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
|
||||
self.timeline_reset_start(ctx)?;
|
||||
@ -829,6 +840,7 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
kwargs: Option<HashMap<StrRef, ValueEnum<'ctx>>>, // New parameter for keyword arguments
|
||||
generator: &mut dyn CodeGenerator,
|
||||
is_async: bool,
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||
@ -836,147 +848,72 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let size_type = ctx.get_size_type();
|
||||
let ptr_type = int8.ptr_type(AddressSpace::default());
|
||||
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
||||
|
||||
let service_id = int32.const_int(fun.1 .0 as u64, false);
|
||||
// -- setup rpc tags
|
||||
let mut tag = Vec::new();
|
||||
if obj.is_some() {
|
||||
tag.push(b'O');
|
||||
|
||||
// Handle both positional and keyword arguments
|
||||
let mut mapping: HashMap<StrRef, ValueEnum<'ctx>> = HashMap::new();
|
||||
|
||||
// Add positional arguments first
|
||||
let mut keys = fun.0.args.clone();
|
||||
for (key, value) in args {
|
||||
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
||||
}
|
||||
for arg in &fun.0.args {
|
||||
gen_rpc_tag(ctx, arg.ty, false, &mut tag)?;
|
||||
|
||||
// Add keyword arguments if provided
|
||||
if let Some(kwargs) = kwargs {
|
||||
for (key, value) in kwargs {
|
||||
mapping.insert(key, value);
|
||||
}
|
||||
}
|
||||
tag.push(b':');
|
||||
gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?;
|
||||
|
||||
let mut hasher = DefaultHasher::new();
|
||||
tag.hash(&mut hasher);
|
||||
let hash = format!("{}", hasher.finish());
|
||||
let mut real_params = fun
|
||||
.0
|
||||
.args
|
||||
.iter()
|
||||
.map(|arg| {
|
||||
mapping
|
||||
.remove(&arg.name)
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, arg.ty)
|
||||
.map(|llvm_val| (llvm_val, arg.ty))
|
||||
})
|
||||
.collect::<Result<Vec<(_, _)>, _>>()?;
|
||||
|
||||
let maybe_existing = ctx.module.get_global(&hash);
|
||||
let tag_ptr = if let Some(gv) = maybe_existing {
|
||||
gv.as_pointer_value()
|
||||
} else {
|
||||
let tag_len = tag.len();
|
||||
let arr_ty = int8.array_type(tag_len as u32);
|
||||
let tag_const = int8
|
||||
.const_array(&tag.iter().map(|&b| int8.const_int(u64::from(b), false)).collect_vec());
|
||||
let arr_gv = ctx.module.add_global(arr_ty, None, &format!("{hash}.arr"));
|
||||
arr_gv.set_linkage(Linkage::Private);
|
||||
arr_gv.set_initializer(&tag_const);
|
||||
|
||||
let st = ctx.ctx.const_struct(
|
||||
&[
|
||||
arr_gv.as_pointer_value().const_cast(ptr_type).into(),
|
||||
size_type.const_int(tag_len as u64, false).into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
let st_gv = ctx.module.add_global(st.get_type(), None, &hash);
|
||||
st_gv.set_linkage(Linkage::Private);
|
||||
st_gv.set_initializer(&st);
|
||||
st_gv.as_pointer_value()
|
||||
};
|
||||
|
||||
// -- rpc args handling
|
||||
let n_params = fun.0.args.len();
|
||||
let mut param_map: Vec<Option<ValueEnum<'ctx>>> = vec![None; n_params];
|
||||
|
||||
let mut pos_index = 0usize;
|
||||
|
||||
if let Some((_obj_ty, obj_val)) = obj {
|
||||
param_map[0] = Some(obj_val);
|
||||
pos_index = 1;
|
||||
}
|
||||
for (maybe_key, val_enum) in args {
|
||||
if let Some(kw_name) = maybe_key {
|
||||
let param_pos = fun
|
||||
.0
|
||||
.args
|
||||
.iter()
|
||||
.position(|arg| arg.name == kw_name)
|
||||
.ok_or_else(|| format!("Unknown keyword argument '{kw_name}'"))?;
|
||||
|
||||
if param_map[param_pos].is_some() {
|
||||
return Err(format!("Multiple values for argument '{kw_name}'"));
|
||||
}
|
||||
param_map[param_pos] = Some(val_enum);
|
||||
// Existing logic for generating the RPC call remains largely unchanged
|
||||
if let Some(obj) = obj {
|
||||
if let ValueEnum::Static(obj_val) = obj.1 {
|
||||
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
|
||||
} else {
|
||||
while pos_index < n_params && param_map[pos_index].is_some() {
|
||||
pos_index += 1;
|
||||
}
|
||||
if pos_index >= n_params {
|
||||
return Err("Too many positional arguments given to function.".to_string());
|
||||
}
|
||||
param_map[pos_index] = Some(val_enum);
|
||||
pos_index += 1;
|
||||
panic!("only host object is allowed");
|
||||
}
|
||||
}
|
||||
for (i, param) in fun.0.args.iter().enumerate() {
|
||||
if param_map[i].is_none() {
|
||||
if let Some(default_expr) = ¶m.default_value {
|
||||
let default_val = ctx.gen_symbol_val(generator, default_expr, param.ty).into();
|
||||
param_map[i] = Some(default_val);
|
||||
} else {
|
||||
return Err(format!("Missing required argument '{}'", param.name));
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut real_params = Vec::with_capacity(n_params);
|
||||
for (i, param_spec) in fun.0.args.iter().enumerate() {
|
||||
let some_valenum = param_map[i].take().unwrap();
|
||||
let llvm_val = some_valenum.to_basic_value_enum(ctx, generator, param_spec.ty)?;
|
||||
real_params.push((llvm_val, param_spec.ty));
|
||||
}
|
||||
|
||||
let arg_count = real_params.len() as u64;
|
||||
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
|
||||
|
||||
let i32_ty = ctx.ctx.i32_type();
|
||||
let arg_array = ctx
|
||||
.builder
|
||||
.build_array_alloca(ptr_type, i32_ty.const_int(arg_count, false), "rpc.arg_array")
|
||||
.unwrap();
|
||||
|
||||
for (i, (llvm_val, ty)) in real_params.iter().enumerate() {
|
||||
let arg_slot_ptr = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
arg_array,
|
||||
&[i32_ty.const_int(i as u64, false)],
|
||||
&format!("rpc.arg_slot_{i}"),
|
||||
// Generate the RPC call as before, but with the updated `real_params`
|
||||
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
|
||||
let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i));
|
||||
let arg_ptr = unsafe {
|
||||
ctx.builder.build_gep(
|
||||
args_ptr,
|
||||
&[int32.const_int(i as u64, false)],
|
||||
&format!("rpc.arg{i}"),
|
||||
)
|
||||
}
|
||||
.unwrap();
|
||||
let arg_ptr = format_rpc_arg(generator, ctx, (*llvm_val, *ty, i));
|
||||
ctx.builder.build_store(arg_slot_ptr, arg_ptr).unwrap();
|
||||
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
|
||||
}
|
||||
|
||||
// call
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
if is_async { "rpc_send_async" } else { "rpc_send" },
|
||||
None,
|
||||
&[service_id.into(), tag_ptr.into(), arg_array.into()],
|
||||
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
|
||||
Some("rpc.send"),
|
||||
None,
|
||||
);
|
||||
|
||||
// reclaim stack space used by arguments
|
||||
call_stackrestore(ctx, stackptr);
|
||||
|
||||
if is_async {
|
||||
// async RPCs do not return any values
|
||||
Ok(None)
|
||||
} else {
|
||||
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
||||
|
||||
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
|
||||
// An RPC returning an NDArray would not touch here.
|
||||
call_stackrestore(ctx, stackptr);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
Ok(format_rpc_ret(generator, ctx, fun.0.ret))
|
||||
}
|
||||
|
||||
pub fn attributes_writeback<'ctx>(
|
||||
|
Loading…
x
Reference in New Issue
Block a user