forked from M-Labs/nac3
Compare commits
1 Commits
master
...
feature/rp
Author | SHA1 | Date | |
---|---|---|---|
79931365b7 |
@ -114,13 +114,26 @@ def extern(function):
|
||||
|
||||
|
||||
def rpc(arg=None, flags={}):
|
||||
"""Decorates a function or method to be executed on the host interpreter."""
|
||||
"""Decorates a function to be executed on the host interpreter with kwargs support."""
|
||||
def decorator(function):
|
||||
@wraps(function)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Get function signature
|
||||
sig = inspect.signature(function)
|
||||
|
||||
# Validate kwargs against signature
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
# Call RPC with both args and kwargs
|
||||
return _do_rpc(function.__name__,
|
||||
bound_args.args,
|
||||
bound_args.kwargs)
|
||||
return wrapper
|
||||
|
||||
if arg is None:
|
||||
def inner_decorator(function):
|
||||
return rpc(function, flags)
|
||||
return inner_decorator
|
||||
register_function(arg)
|
||||
return arg
|
||||
return decorator
|
||||
return decorator(arg)
|
||||
|
||||
def kernel(function_or_method):
|
||||
"""Decorates a function or method to be executed on the core device."""
|
||||
|
@ -79,8 +79,7 @@ pub struct ArtiqCodeGenerator<'a> {
|
||||
|
||||
/// The [`ParallelMode`] of the current parallel context.
|
||||
///
|
||||
/// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel`
|
||||
/// statement, which is used to determine when and how the timeline should be updated.
|
||||
/// The current parallel context refers to the nearest `with` statement, which is used to determine when and how the timeline should be updated.
|
||||
parallel_mode: ParallelMode,
|
||||
}
|
||||
|
||||
@ -373,8 +372,14 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
|
||||
fn gen_rpc_tag(
|
||||
ctx: &mut CodeGenContext<'_, '_>,
|
||||
ty: Type,
|
||||
is_kwarg: bool, // Add this parameter
|
||||
buffer: &mut Vec<u8>,
|
||||
) -> Result<(), String> {
|
||||
// Add kwarg marker if needed
|
||||
if is_kwarg {
|
||||
buffer.push(b'k'); // 'k' for keyword argument
|
||||
}
|
||||
|
||||
use nac3core::typecheck::typedef::TypeEnum::*;
|
||||
|
||||
let int32 = ctx.primitives.int32;
|
||||
@ -403,14 +408,14 @@ fn gen_rpc_tag(
|
||||
buffer.push(b't');
|
||||
buffer.push(ty.len() as u8);
|
||||
for ty in ty {
|
||||
gen_rpc_tag(ctx, *ty, buffer)?;
|
||||
gen_rpc_tag(ctx, *ty, false, buffer)?; // Pass false for is_kwarg
|
||||
}
|
||||
}
|
||||
TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
||||
let ty = iter_type_vars(params).next().unwrap().ty;
|
||||
|
||||
buffer.push(b'l');
|
||||
gen_rpc_tag(ctx, ty, buffer)?;
|
||||
gen_rpc_tag(ctx, ty, false, buffer)?; // Pass false for is_kwarg
|
||||
}
|
||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
@ -434,7 +439,7 @@ fn gen_rpc_tag(
|
||||
|
||||
buffer.push(b'a');
|
||||
buffer.push((ndarray_ndims & 0xFF) as u8);
|
||||
gen_rpc_tag(ctx, ndarray_dtype, buffer)?;
|
||||
gen_rpc_tag(ctx, ndarray_dtype, false, buffer)?; // Pass false for is_kwarg
|
||||
}
|
||||
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
||||
}
|
||||
@ -808,10 +813,10 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
tag.push(b'O');
|
||||
}
|
||||
for arg in &fun.0.args {
|
||||
gen_rpc_tag(ctx, arg.ty, &mut tag)?;
|
||||
gen_rpc_tag(ctx, arg.ty, false, &mut tag)?; // Pass false for is_kwarg
|
||||
}
|
||||
tag.push(b':');
|
||||
gen_rpc_tag(ctx, fun.0.ret, &mut tag)?;
|
||||
gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?;
|
||||
|
||||
let mut hasher = DefaultHasher::new();
|
||||
tag.hash(&mut hasher);
|
||||
@ -858,8 +863,17 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
// -- rpc args handling
|
||||
let mut keys = fun.0.args.clone();
|
||||
let mut mapping = HashMap::new();
|
||||
let mut is_keyword_arg = HashMap::new();
|
||||
|
||||
for (key, value) in args {
|
||||
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
||||
if let Some(key_name) = key {
|
||||
mapping.insert(key_name, value);
|
||||
is_keyword_arg.insert(key_name, true);
|
||||
} else {
|
||||
let arg_name = keys.remove(0).name;
|
||||
mapping.insert(arg_name, value);
|
||||
is_keyword_arg.insert(arg_name, false);
|
||||
}
|
||||
}
|
||||
// default value handling
|
||||
for k in keys {
|
||||
@ -901,6 +915,14 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
|
||||
}
|
||||
|
||||
// Before calling rpc_send/rpc_send_async, add keyword arg info to tag
|
||||
for arg in &fun.0.args {
|
||||
if *is_keyword_arg.get(&arg.name).unwrap_or(&false) {
|
||||
tag.push(b'k'); // Mark as keyword argument
|
||||
}
|
||||
gen_rpc_tag(ctx, arg.ty, true, &mut tag)?; // Pass true for is_kwarg
|
||||
}
|
||||
|
||||
// call
|
||||
if is_async {
|
||||
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
|
||||
@ -1007,7 +1029,7 @@ pub fn attributes_writeback<'ctx>(
|
||||
if !is_mutable {
|
||||
continue;
|
||||
}
|
||||
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
|
||||
if gen_rpc_tag(ctx, *field_ty, false, &mut scratch_buffer).is_ok() {
|
||||
attributes.push(name.to_string());
|
||||
let (index, _) = ctx.get_attr_index(ty, *name);
|
||||
values.push((
|
||||
@ -1030,7 +1052,7 @@ pub fn attributes_writeback<'ctx>(
|
||||
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
||||
let elem_ty = iter_type_vars(params).next().unwrap().ty;
|
||||
|
||||
if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() {
|
||||
if gen_rpc_tag(ctx, elem_ty, false, &mut scratch_buffer).is_ok() {
|
||||
let pydict = PyDict::new(py);
|
||||
pydict.set_item("obj", val)?;
|
||||
host_attributes.append(pydict)?;
|
||||
|
@ -1832,20 +1832,47 @@ impl<'a> Inferencer<'a> {
|
||||
|
||||
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) {
|
||||
if sign.vars.is_empty() {
|
||||
// Build keyword argument map
|
||||
let mut kwargs_map = HashMap::new();
|
||||
for kw in &keywords {
|
||||
if let Some(name) = &kw.node.arg {
|
||||
// Check if keyword arg exists in function signature
|
||||
if !sign.args.iter().any(|arg| arg.name == *name) {
|
||||
return report_error(
|
||||
&format!("Unexpected keyword argument '{}'", name),
|
||||
kw.location,
|
||||
);
|
||||
}
|
||||
kwargs_map.insert(*name, kw.node.value.custom.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that all required args are provided
|
||||
for arg in &sign.args {
|
||||
if arg.default_value.is_none()
|
||||
&& !kwargs_map.contains_key(&arg.name)
|
||||
&& args.len() < sign.args.len()
|
||||
{
|
||||
return report_error(
|
||||
&format!("Missing required argument '{}'", arg.name),
|
||||
location,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let call = Call {
|
||||
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
||||
kwargs: keywords
|
||||
.iter()
|
||||
.map(|v| (*v.node.arg.as_ref().unwrap(), v.node.value.custom.unwrap()))
|
||||
.collect(),
|
||||
kwargs: kwargs_map,
|
||||
fun: RefCell::new(None),
|
||||
ret: sign.ret,
|
||||
loc: Some(location),
|
||||
operator_info: None,
|
||||
};
|
||||
|
||||
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
|
||||
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
|
||||
})?;
|
||||
|
||||
return Ok(Located {
|
||||
location,
|
||||
custom: Some(sign.ret),
|
||||
@ -1859,7 +1886,7 @@ impl<'a> Inferencer<'a> {
|
||||
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
||||
kwargs: keywords
|
||||
.iter()
|
||||
.map(|v| (*v.node.arg.as_ref().unwrap(), v.custom.unwrap()))
|
||||
.filter_map(|v| v.node.arg.map(|name| (name, v.node.value.custom.unwrap())))
|
||||
.collect(),
|
||||
fun: RefCell::new(None),
|
||||
ret,
|
||||
|
Loading…
Reference in New Issue
Block a user