Implement passing of kwarg
This commit is contained in:
parent
ec2787aaf6
commit
c083245c28
@ -792,11 +792,11 @@ fn format_rpc_ret<'ctx>(
|
||||
Some(result)
|
||||
}
|
||||
|
||||
fn rpc_codegen_callback_fn<'ctx>(
|
||||
pub fn rpc_codegen_callback_fn<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
mut args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
is_async: bool,
|
||||
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
|
||||
@ -804,136 +804,140 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let size_type = generator.get_size_type(ctx.ctx);
|
||||
let ptr_type = int8.ptr_type(AddressSpace::default());
|
||||
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
||||
|
||||
let tag_ptr_type = ctx
|
||||
.ctx
|
||||
.struct_type(&[ptr_type.into(), size_type.into()], false)
|
||||
.ptr_type(AddressSpace::default());
|
||||
let service_id = int32.const_int(fun.1 .0 as u64, false);
|
||||
// -- setup rpc tags
|
||||
let mut tag = Vec::new();
|
||||
if obj.is_some() {
|
||||
tag.push(b'O');
|
||||
}
|
||||
for arg in &fun.0.args {
|
||||
gen_rpc_tag(ctx, arg.ty, false, &mut tag)?; // Pass false for is_kwarg
|
||||
for param in &fun.0.args {
|
||||
gen_rpc_tag(ctx, param.ty, false, &mut tag)?;
|
||||
}
|
||||
tag.push(b':');
|
||||
gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?;
|
||||
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = DefaultHasher::new();
|
||||
tag.hash(&mut hasher);
|
||||
let hash = format!("{}", hasher.finish());
|
||||
let hash = format!("rpc_tag_{}", hasher.finish());
|
||||
|
||||
let tag_ptr = ctx
|
||||
.module
|
||||
.get_global(hash.as_str())
|
||||
.unwrap_or_else(|| {
|
||||
let tag_arr_ptr = ctx.module.add_global(
|
||||
int8.array_type(tag.len() as u32),
|
||||
None,
|
||||
format!("tagptr{}", fun.1 .0).as_str(),
|
||||
);
|
||||
tag_arr_ptr.set_initializer(&int8.const_array(
|
||||
&tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(),
|
||||
));
|
||||
tag_arr_ptr.set_linkage(Linkage::Private);
|
||||
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
|
||||
tag_ptr.set_linkage(Linkage::Private);
|
||||
tag_ptr.set_initializer(&ctx.ctx.const_struct(
|
||||
&[
|
||||
tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(),
|
||||
size_type.const_int(tag.len() as u64, false).into(),
|
||||
],
|
||||
false,
|
||||
));
|
||||
tag_ptr
|
||||
})
|
||||
.as_pointer_value();
|
||||
let maybe_existing = ctx.module.get_global(&hash);
|
||||
let tag_ptr = if let Some(gv) = maybe_existing {
|
||||
gv.as_pointer_value()
|
||||
} else {
|
||||
let tag_len = tag.len();
|
||||
let arr_ty = int8.array_type(tag_len as u32);
|
||||
let tag_const = int8.const_array(
|
||||
&tag.iter().map(|&b| int8.const_int(b as u64, false)).collect::<Vec<_>>(),
|
||||
);
|
||||
let arr_gv = ctx
|
||||
.module
|
||||
.add_global(arr_ty, None, &format!("{}.arr", hash));
|
||||
arr_gv.set_linkage(Linkage::Private);
|
||||
arr_gv.set_initializer(&tag_const);
|
||||
|
||||
let arg_length = args.len() + usize::from(obj.is_some());
|
||||
let st = ctx.ctx.const_struct(
|
||||
&[
|
||||
arr_gv.as_pointer_value().const_cast(ptr_type).into(),
|
||||
size_type.const_int(tag_len as u64, false).into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
let st_gv = ctx.module.add_global(st.get_type(), None, &hash);
|
||||
st_gv.set_linkage(Linkage::Private);
|
||||
st_gv.set_initializer(&st);
|
||||
|
||||
st_gv.as_pointer_value()
|
||||
};
|
||||
|
||||
let n_params = fun.0.args.len();
|
||||
let mut param_map: Vec<Option<ValueEnum<'ctx>>> = vec![None; n_params];
|
||||
|
||||
let mut pos_index = 0usize;
|
||||
|
||||
if let Some((obj_ty, obj_val)) = obj {
|
||||
param_map[0] = Some(obj_val);
|
||||
pos_index = 1;
|
||||
}
|
||||
for (maybe_key, val_enum) in args.drain(..) {
|
||||
if let Some(kw_name) = maybe_key {
|
||||
let param_pos = fun.0
|
||||
.args
|
||||
.iter()
|
||||
.position(|arg| arg.name == kw_name)
|
||||
.ok_or_else(|| format!("Unknown keyword argument '{}'", kw_name))?;
|
||||
|
||||
if param_map[param_pos].is_some() {
|
||||
return Err(format!("Multiple values for argument '{}'", kw_name));
|
||||
}
|
||||
param_map[param_pos] = Some(val_enum);
|
||||
} else {
|
||||
while pos_index < n_params && param_map[pos_index].is_some() {
|
||||
pos_index += 1;
|
||||
}
|
||||
if pos_index >= n_params {
|
||||
return Err("Too many positional arguments given to function.".to_string());
|
||||
}
|
||||
param_map[pos_index] = Some(val_enum);
|
||||
pos_index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for (i, param) in fun.0.args.iter().enumerate() {
|
||||
if param_map[i].is_none() {
|
||||
if let Some(default_expr) = ¶m.default_value {
|
||||
let default_val = ctx.gen_symbol_val(generator, default_expr, param.ty).into();
|
||||
param_map[i] = Some(default_val);
|
||||
} else {
|
||||
return Err(format!("Missing required argument '{}'", param.name));
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut real_params = Vec::with_capacity(n_params);
|
||||
for (i, param_spec) in fun.0.args.iter().enumerate() {
|
||||
let some_valenum = param_map[i].take().unwrap();
|
||||
let llvm_val = some_valenum.to_basic_value_enum(ctx, generator, param_spec.ty)?;
|
||||
real_params.push((llvm_val, param_spec.ty));
|
||||
}
|
||||
|
||||
let arg_count = real_params.len() as u64;
|
||||
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
|
||||
let args_ptr = ctx
|
||||
|
||||
let i32_ty = ctx.ctx.i32_type();
|
||||
let arg_array = ctx
|
||||
.builder
|
||||
.build_array_alloca(
|
||||
ptr_type,
|
||||
ctx.ctx.i32_type().const_int(arg_length as u64, false),
|
||||
"argptr",
|
||||
i32_ty.const_int(arg_count, false),
|
||||
"rpc.arg_array",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// -- rpc args handling
|
||||
let mut keys = fun.0.args.clone();
|
||||
let mut mapping = HashMap::new();
|
||||
let mut is_keyword_arg = HashMap::new();
|
||||
|
||||
for (key, value) in args {
|
||||
if let Some(key_name) = key {
|
||||
mapping.insert(key_name, value);
|
||||
is_keyword_arg.insert(key_name, true);
|
||||
} else {
|
||||
let arg_name = keys.remove(0).name;
|
||||
mapping.insert(arg_name, value);
|
||||
is_keyword_arg.insert(arg_name, false);
|
||||
}
|
||||
}
|
||||
// default value handling
|
||||
for k in keys {
|
||||
mapping
|
||||
.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
|
||||
}
|
||||
// reorder the parameters
|
||||
let mut real_params = fun
|
||||
.0
|
||||
.args
|
||||
.iter()
|
||||
.map(|arg| {
|
||||
mapping
|
||||
.remove(&arg.name)
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, arg.ty)
|
||||
.map(|llvm_val| (llvm_val, arg.ty))
|
||||
})
|
||||
.collect::<Result<Vec<(_, _)>, _>>()?;
|
||||
if let Some(obj) = obj {
|
||||
if let ValueEnum::Static(obj_val) = obj.1 {
|
||||
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
|
||||
} else {
|
||||
// should be an error here...
|
||||
panic!("only host object is allowed");
|
||||
}
|
||||
}
|
||||
|
||||
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
|
||||
let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i));
|
||||
let arg_ptr = unsafe {
|
||||
for (i, (llvm_val, ty)) in real_params.iter().enumerate() {
|
||||
let arg_slot_ptr = unsafe {
|
||||
ctx.builder.build_gep(
|
||||
args_ptr,
|
||||
&[int32.const_int(i as u64, false)],
|
||||
&format!("rpc.arg{i}"),
|
||||
arg_array,
|
||||
&[
|
||||
i32_ty.const_int(i as u64, false),
|
||||
],
|
||||
&format!("rpc.arg_slot_{}", i),
|
||||
)
|
||||
}
|
||||
.unwrap();
|
||||
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
|
||||
}.unwrap();
|
||||
let arg_ptr = format_rpc_arg(generator, ctx, (*llvm_val, *ty, i));
|
||||
ctx.builder.build_store(arg_slot_ptr, arg_ptr).unwrap();
|
||||
}
|
||||
|
||||
// Before calling rpc_send/rpc_send_async, add keyword arg info to tag
|
||||
for arg in &fun.0.args {
|
||||
if *is_keyword_arg.get(&arg.name).unwrap_or(&false) {
|
||||
tag.push(b'k'); // Mark as keyword argument
|
||||
}
|
||||
gen_rpc_tag(ctx, arg.ty, true, &mut tag)?; // Pass true for is_kwarg
|
||||
}
|
||||
|
||||
// call
|
||||
if is_async {
|
||||
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
|
||||
ctx.module.add_function(
|
||||
"rpc_send_async",
|
||||
ctx.ctx.void_type().fn_type(
|
||||
&[
|
||||
int32.into(),
|
||||
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||
ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||
],
|
||||
&[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()],
|
||||
false,
|
||||
),
|
||||
None,
|
||||
@ -942,45 +946,42 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
ctx.builder
|
||||
.build_call(
|
||||
rpc_send_async,
|
||||
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
|
||||
"rpc.send",
|
||||
&[
|
||||
service_id.into(),
|
||||
tag_ptr.into(),
|
||||
arg_array.into(),
|
||||
],
|
||||
"rpc.send_async",
|
||||
)
|
||||
.unwrap();
|
||||
call_stackrestore(ctx, stackptr);
|
||||
return Ok(None);
|
||||
} else {
|
||||
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
|
||||
ctx.module.add_function(
|
||||
"rpc_send",
|
||||
ctx.ctx.void_type().fn_type(
|
||||
&[
|
||||
int32.into(),
|
||||
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||
ptr_type.ptr_type(AddressSpace::default()).into(),
|
||||
],
|
||||
&[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()],
|
||||
false,
|
||||
),
|
||||
None,
|
||||
)
|
||||
});
|
||||
ctx.builder
|
||||
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
|
||||
.build_call(
|
||||
rpc_send,
|
||||
&[
|
||||
service_id.into(),
|
||||
tag_ptr.into(),
|
||||
arg_array.into(),
|
||||
],
|
||||
"rpc.send",
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
call_stackrestore(ctx, stackptr);
|
||||
|
||||
// reclaim stack space used by arguments
|
||||
call_stackrestore(ctx, stackptr);
|
||||
|
||||
if is_async {
|
||||
// async RPCs do not return any values
|
||||
Ok(None)
|
||||
} else {
|
||||
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
||||
|
||||
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
|
||||
// An RPC returning an NDArray would not touch here.
|
||||
call_stackrestore(ctx, stackptr);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
let maybe_ret = format_rpc_ret(generator, ctx, fun.0.ret);
|
||||
Ok(maybe_ret)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,7 +3,7 @@ use std::{
|
||||
cmp::max,
|
||||
collections::{HashMap, HashSet},
|
||||
convert::{From, TryInto},
|
||||
iter::once,
|
||||
iter::{once, repeat_n},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
@ -187,7 +187,7 @@ fn fix_assignment_target_context(node: &mut ast::Located<ExprKind>) {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Fold<()> for Inferencer<'a> {
|
||||
impl Fold<()> for Inferencer<'_> {
|
||||
type TargetU = Option<Type>;
|
||||
type Error = InferenceError;
|
||||
|
||||
@ -657,7 +657,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||
|
||||
type InferenceResult = Result<Type, InferenceError>;
|
||||
|
||||
impl<'a> Inferencer<'a> {
|
||||
impl Inferencer<'_> {
|
||||
/// Constrain a <: b
|
||||
/// Currently implemented as unification
|
||||
fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> {
|
||||
@ -1234,6 +1234,45 @@ impl<'a> Inferencer<'a> {
|
||||
}));
|
||||
}
|
||||
|
||||
if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 {
|
||||
let ndarray = self.fold_expr(args.remove(0))?;
|
||||
|
||||
let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap());
|
||||
|
||||
// Make a tuple of size `ndims` full of int32 (TODO: Make it usize)
|
||||
let ret_ty = TypeEnum::TTuple {
|
||||
ty: repeat_n(self.primitives.int32, ndims as usize).collect_vec(),
|
||||
is_vararg_ctx: false,
|
||||
};
|
||||
let ret_ty = self.unifier.add_ty(ret_ty);
|
||||
|
||||
let func_ty = TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg {
|
||||
name: "a".into(),
|
||||
default_value: None,
|
||||
ty: ndarray.custom.unwrap(),
|
||||
is_vararg: false,
|
||||
}],
|
||||
ret: ret_ty,
|
||||
vars: VarMap::new(),
|
||||
});
|
||||
let func_ty = self.unifier.add_ty(func_ty);
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret_ty),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(func_ty),
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||
}),
|
||||
args: vec![ndarray],
|
||||
keywords: vec![],
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
if id == &"np_dot".into() {
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let arg1 = self.fold_expr(args.remove(0))?;
|
||||
@ -1555,7 +1594,7 @@ impl<'a> Inferencer<'a> {
|
||||
}));
|
||||
}
|
||||
// 2-argument ndarray n-dimensional factory functions
|
||||
if id == &"np_reshape".into() && args.len() == 2 {
|
||||
if ["np_reshape".into(), "np_broadcast_to".into()].contains(id) && args.len() == 2 {
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
|
||||
let shape_expr = args.remove(0);
|
||||
@ -1832,47 +1871,20 @@ impl<'a> Inferencer<'a> {
|
||||
|
||||
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) {
|
||||
if sign.vars.is_empty() {
|
||||
// Build keyword argument map
|
||||
let mut kwargs_map = HashMap::new();
|
||||
for kw in &keywords {
|
||||
if let Some(name) = &kw.node.arg {
|
||||
// Check if keyword arg exists in function signature
|
||||
if !sign.args.iter().any(|arg| arg.name == *name) {
|
||||
return report_error(
|
||||
&format!("Unexpected keyword argument '{}'", name),
|
||||
kw.location,
|
||||
);
|
||||
}
|
||||
kwargs_map.insert(*name, kw.node.value.custom.unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that all required args are provided
|
||||
for arg in &sign.args {
|
||||
if arg.default_value.is_none()
|
||||
&& !kwargs_map.contains_key(&arg.name)
|
||||
&& args.len() < sign.args.len()
|
||||
{
|
||||
return report_error(
|
||||
&format!("Missing required argument '{}'", arg.name),
|
||||
location,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let call = Call {
|
||||
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
||||
kwargs: kwargs_map,
|
||||
kwargs: keywords
|
||||
.iter()
|
||||
.map(|v| (*v.node.arg.as_ref().unwrap(), v.node.value.custom.unwrap()))
|
||||
.collect(),
|
||||
fun: RefCell::new(None),
|
||||
ret: sign.ret,
|
||||
loc: Some(location),
|
||||
operator_info: None,
|
||||
};
|
||||
|
||||
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
|
||||
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
|
||||
})?;
|
||||
|
||||
return Ok(Located {
|
||||
location,
|
||||
custom: Some(sign.ret),
|
||||
@ -1886,7 +1898,7 @@ impl<'a> Inferencer<'a> {
|
||||
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
||||
kwargs: keywords
|
||||
.iter()
|
||||
.filter_map(|v| v.node.arg.map(|name| (name, v.node.value.custom.unwrap())))
|
||||
.map(|v| (*v.node.arg.as_ref().unwrap(), v.custom.unwrap()))
|
||||
.collect(),
|
||||
fun: RefCell::new(None),
|
||||
ret,
|
||||
|
Loading…
Reference in New Issue
Block a user