Implement passing of kwarg

This commit is contained in:
ram 2025-02-05 07:20:13 +00:00
parent ec2787aaf6
commit c083245c28
2 changed files with 174 additions and 161 deletions

View File

@ -792,11 +792,11 @@ fn format_rpc_ret<'ctx>(
Some(result)
}
fn rpc_codegen_callback_fn<'ctx>(
pub fn rpc_codegen_callback_fn<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
mut args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator,
is_async: bool,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
@ -804,136 +804,140 @@ fn rpc_codegen_callback_fn<'ctx>(
let int32 = ctx.ctx.i32_type();
let size_type = generator.get_size_type(ctx.ctx);
let ptr_type = int8.ptr_type(AddressSpace::default());
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let tag_ptr_type = ctx
.ctx
.struct_type(&[ptr_type.into(), size_type.into()], false)
.ptr_type(AddressSpace::default());
let service_id = int32.const_int(fun.1 .0 as u64, false);
// -- setup rpc tags
let mut tag = Vec::new();
if obj.is_some() {
tag.push(b'O');
}
for arg in &fun.0.args {
gen_rpc_tag(ctx, arg.ty, false, &mut tag)?; // Pass false for is_kwarg
for param in &fun.0.args {
gen_rpc_tag(ctx, param.ty, false, &mut tag)?;
}
tag.push(b':');
gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
tag.hash(&mut hasher);
let hash = format!("{}", hasher.finish());
let hash = format!("rpc_tag_{}", hasher.finish());
let tag_ptr = ctx
.module
.get_global(hash.as_str())
.unwrap_or_else(|| {
let tag_arr_ptr = ctx.module.add_global(
int8.array_type(tag.len() as u32),
None,
format!("tagptr{}", fun.1 .0).as_str(),
);
tag_arr_ptr.set_initializer(&int8.const_array(
&tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(),
));
tag_arr_ptr.set_linkage(Linkage::Private);
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
tag_ptr.set_linkage(Linkage::Private);
tag_ptr.set_initializer(&ctx.ctx.const_struct(
&[
tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(),
size_type.const_int(tag.len() as u64, false).into(),
],
false,
));
tag_ptr
})
.as_pointer_value();
let maybe_existing = ctx.module.get_global(&hash);
let tag_ptr = if let Some(gv) = maybe_existing {
gv.as_pointer_value()
} else {
let tag_len = tag.len();
let arr_ty = int8.array_type(tag_len as u32);
let tag_const = int8.const_array(
&tag.iter().map(|&b| int8.const_int(b as u64, false)).collect::<Vec<_>>(),
);
let arr_gv = ctx
.module
.add_global(arr_ty, None, &format!("{}.arr", hash));
arr_gv.set_linkage(Linkage::Private);
arr_gv.set_initializer(&tag_const);
let arg_length = args.len() + usize::from(obj.is_some());
let st = ctx.ctx.const_struct(
&[
arr_gv.as_pointer_value().const_cast(ptr_type).into(),
size_type.const_int(tag_len as u64, false).into(),
],
false,
);
let st_gv = ctx.module.add_global(st.get_type(), None, &hash);
st_gv.set_linkage(Linkage::Private);
st_gv.set_initializer(&st);
st_gv.as_pointer_value()
};
let n_params = fun.0.args.len();
let mut param_map: Vec<Option<ValueEnum<'ctx>>> = vec![None; n_params];
let mut pos_index = 0usize;
if let Some((obj_ty, obj_val)) = obj {
param_map[0] = Some(obj_val);
pos_index = 1;
}
for (maybe_key, val_enum) in args.drain(..) {
if let Some(kw_name) = maybe_key {
let param_pos = fun.0
.args
.iter()
.position(|arg| arg.name == kw_name)
.ok_or_else(|| format!("Unknown keyword argument '{}'", kw_name))?;
if param_map[param_pos].is_some() {
return Err(format!("Multiple values for argument '{}'", kw_name));
}
param_map[param_pos] = Some(val_enum);
} else {
while pos_index < n_params && param_map[pos_index].is_some() {
pos_index += 1;
}
if pos_index >= n_params {
return Err("Too many positional arguments given to function.".to_string());
}
param_map[pos_index] = Some(val_enum);
pos_index += 1;
}
}
for (i, param) in fun.0.args.iter().enumerate() {
if param_map[i].is_none() {
if let Some(default_expr) = &param.default_value {
let default_val = ctx.gen_symbol_val(generator, default_expr, param.ty).into();
param_map[i] = Some(default_val);
} else {
return Err(format!("Missing required argument '{}'", param.name));
}
}
}
let mut real_params = Vec::with_capacity(n_params);
for (i, param_spec) in fun.0.args.iter().enumerate() {
let some_valenum = param_map[i].take().unwrap();
let llvm_val = some_valenum.to_basic_value_enum(ctx, generator, param_spec.ty)?;
real_params.push((llvm_val, param_spec.ty));
}
let arg_count = real_params.len() as u64;
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
let args_ptr = ctx
let i32_ty = ctx.ctx.i32_type();
let arg_array = ctx
.builder
.build_array_alloca(
ptr_type,
ctx.ctx.i32_type().const_int(arg_length as u64, false),
"argptr",
i32_ty.const_int(arg_count, false),
"rpc.arg_array",
)
.unwrap();
// -- rpc args handling
let mut keys = fun.0.args.clone();
let mut mapping = HashMap::new();
let mut is_keyword_arg = HashMap::new();
for (key, value) in args {
if let Some(key_name) = key {
mapping.insert(key_name, value);
is_keyword_arg.insert(key_name, true);
} else {
let arg_name = keys.remove(0).name;
mapping.insert(arg_name, value);
is_keyword_arg.insert(arg_name, false);
}
}
// default value handling
for k in keys {
mapping
.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
}
// reorder the parameters
let mut real_params = fun
.0
.args
.iter()
.map(|arg| {
mapping
.remove(&arg.name)
.unwrap()
.to_basic_value_enum(ctx, generator, arg.ty)
.map(|llvm_val| (llvm_val, arg.ty))
})
.collect::<Result<Vec<(_, _)>, _>>()?;
if let Some(obj) = obj {
if let ValueEnum::Static(obj_val) = obj.1 {
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
} else {
// should be an error here...
panic!("only host object is allowed");
}
}
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i));
let arg_ptr = unsafe {
for (i, (llvm_val, ty)) in real_params.iter().enumerate() {
let arg_slot_ptr = unsafe {
ctx.builder.build_gep(
args_ptr,
&[int32.const_int(i as u64, false)],
&format!("rpc.arg{i}"),
arg_array,
&[
i32_ty.const_int(i as u64, false),
],
&format!("rpc.arg_slot_{}", i),
)
}
.unwrap();
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
}.unwrap();
let arg_ptr = format_rpc_arg(generator, ctx, (*llvm_val, *ty, i));
ctx.builder.build_store(arg_slot_ptr, arg_ptr).unwrap();
}
// Before calling rpc_send/rpc_send_async, add keyword arg info to tag
for arg in &fun.0.args {
if *is_keyword_arg.get(&arg.name).unwrap_or(&false) {
tag.push(b'k'); // Mark as keyword argument
}
gen_rpc_tag(ctx, arg.ty, true, &mut tag)?; // Pass true for is_kwarg
}
// call
if is_async {
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
ctx.module.add_function(
"rpc_send_async",
ctx.ctx.void_type().fn_type(
&[
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
&[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()],
false,
),
None,
@ -942,45 +946,42 @@ fn rpc_codegen_callback_fn<'ctx>(
ctx.builder
.build_call(
rpc_send_async,
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
"rpc.send",
&[
service_id.into(),
tag_ptr.into(),
arg_array.into(),
],
"rpc.send_async",
)
.unwrap();
call_stackrestore(ctx, stackptr);
return Ok(None);
} else {
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
ctx.module.add_function(
"rpc_send",
ctx.ctx.void_type().fn_type(
&[
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
&[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()],
false,
),
None,
)
});
ctx.builder
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
.build_call(
rpc_send,
&[
service_id.into(),
tag_ptr.into(),
arg_array.into(),
],
"rpc.send",
)
.unwrap();
}
call_stackrestore(ctx, stackptr);
// reclaim stack space used by arguments
call_stackrestore(ctx, stackptr);
if is_async {
// async RPCs do not return any values
Ok(None)
} else {
let result = format_rpc_ret(generator, ctx, fun.0.ret);
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
// An RPC returning an NDArray would not touch here.
call_stackrestore(ctx, stackptr);
}
Ok(result)
let maybe_ret = format_rpc_ret(generator, ctx, fun.0.ret);
Ok(maybe_ret)
}
}

View File

@ -3,7 +3,7 @@ use std::{
cmp::max,
collections::{HashMap, HashSet},
convert::{From, TryInto},
iter::once,
iter::{once, repeat_n},
sync::Arc,
};
@ -187,7 +187,7 @@ fn fix_assignment_target_context(node: &mut ast::Located<ExprKind>) {
}
}
impl<'a> Fold<()> for Inferencer<'a> {
impl Fold<()> for Inferencer<'_> {
type TargetU = Option<Type>;
type Error = InferenceError;
@ -657,7 +657,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
type InferenceResult = Result<Type, InferenceError>;
impl<'a> Inferencer<'a> {
impl Inferencer<'_> {
/// Constrain a <: b
/// Currently implemented as unification
fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> {
@ -1234,6 +1234,45 @@ impl<'a> Inferencer<'a> {
}));
}
if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 {
let ndarray = self.fold_expr(args.remove(0))?;
let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap());
// Make a tuple of size `ndims` full of int32 (TODO: Make it usize)
let ret_ty = TypeEnum::TTuple {
ty: repeat_n(self.primitives.int32, ndims as usize).collect_vec(),
is_vararg_ctx: false,
};
let ret_ty = self.unifier.add_ty(ret_ty);
let func_ty = TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "a".into(),
default_value: None,
ty: ndarray.custom.unwrap(),
is_vararg: false,
}],
ret: ret_ty,
vars: VarMap::new(),
});
let func_ty = self.unifier.add_ty(func_ty);
return Ok(Some(Located {
location,
custom: Some(ret_ty),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(func_ty),
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![ndarray],
keywords: vec![],
},
}));
}
if id == &"np_dot".into() {
let arg0 = self.fold_expr(args.remove(0))?;
let arg1 = self.fold_expr(args.remove(0))?;
@ -1555,7 +1594,7 @@ impl<'a> Inferencer<'a> {
}));
}
// 2-argument ndarray n-dimensional factory functions
if id == &"np_reshape".into() && args.len() == 2 {
if ["np_reshape".into(), "np_broadcast_to".into()].contains(id) && args.len() == 2 {
let arg0 = self.fold_expr(args.remove(0))?;
let shape_expr = args.remove(0);
@ -1832,47 +1871,20 @@ impl<'a> Inferencer<'a> {
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) {
if sign.vars.is_empty() {
// Build keyword argument map
let mut kwargs_map = HashMap::new();
for kw in &keywords {
if let Some(name) = &kw.node.arg {
// Check if keyword arg exists in function signature
if !sign.args.iter().any(|arg| arg.name == *name) {
return report_error(
&format!("Unexpected keyword argument '{}'", name),
kw.location,
);
}
kwargs_map.insert(*name, kw.node.value.custom.unwrap());
}
}
// Validate that all required args are provided
for arg in &sign.args {
if arg.default_value.is_none()
&& !kwargs_map.contains_key(&arg.name)
&& args.len() < sign.args.len()
{
return report_error(
&format!("Missing required argument '{}'", arg.name),
location,
);
}
}
let call = Call {
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
kwargs: kwargs_map,
kwargs: keywords
.iter()
.map(|v| (*v.node.arg.as_ref().unwrap(), v.node.value.custom.unwrap()))
.collect(),
fun: RefCell::new(None),
ret: sign.ret,
loc: Some(location),
operator_info: None,
};
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
})?;
return Ok(Located {
location,
custom: Some(sign.ret),
@ -1886,7 +1898,7 @@ impl<'a> Inferencer<'a> {
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
kwargs: keywords
.iter()
.filter_map(|v| v.node.arg.map(|name| (name, v.node.value.custom.unwrap())))
.map(|v| (*v.node.arg.as_ref().unwrap(), v.custom.unwrap()))
.collect(),
fun: RefCell::new(None),
ret,