Implement passing of kwarg

This commit is contained in:
ram 2025-02-05 07:20:13 +00:00
parent ec2787aaf6
commit c083245c28
2 changed files with 174 additions and 161 deletions

View File

@ -792,11 +792,11 @@ fn format_rpc_ret<'ctx>(
Some(result) Some(result)
} }
fn rpc_codegen_callback_fn<'ctx>( pub fn rpc_codegen_callback_fn<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
obj: Option<(Type, ValueEnum<'ctx>)>, obj: Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId), fun: (&FunSignature, DefinitionId),
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>, mut args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
is_async: bool, is_async: bool,
) -> Result<Option<BasicValueEnum<'ctx>>, String> { ) -> Result<Option<BasicValueEnum<'ctx>>, String> {
@ -804,136 +804,140 @@ fn rpc_codegen_callback_fn<'ctx>(
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let size_type = generator.get_size_type(ctx.ctx); let size_type = generator.get_size_type(ctx.ctx);
let ptr_type = int8.ptr_type(AddressSpace::default()); let ptr_type = int8.ptr_type(AddressSpace::default());
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let tag_ptr_type = ctx
.ctx
.struct_type(&[ptr_type.into(), size_type.into()], false)
.ptr_type(AddressSpace::default());
let service_id = int32.const_int(fun.1 .0 as u64, false); let service_id = int32.const_int(fun.1 .0 as u64, false);
// -- setup rpc tags // -- setup rpc tags
let mut tag = Vec::new(); let mut tag = Vec::new();
if obj.is_some() { if obj.is_some() {
tag.push(b'O'); tag.push(b'O');
} }
for arg in &fun.0.args { for param in &fun.0.args {
gen_rpc_tag(ctx, arg.ty, false, &mut tag)?; // Pass false for is_kwarg gen_rpc_tag(ctx, param.ty, false, &mut tag)?;
} }
tag.push(b':'); tag.push(b':');
gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?; gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new(); let mut hasher = DefaultHasher::new();
tag.hash(&mut hasher); tag.hash(&mut hasher);
let hash = format!("{}", hasher.finish()); let hash = format!("rpc_tag_{}", hasher.finish());
let tag_ptr = ctx let maybe_existing = ctx.module.get_global(&hash);
.module let tag_ptr = if let Some(gv) = maybe_existing {
.get_global(hash.as_str()) gv.as_pointer_value()
.unwrap_or_else(|| { } else {
let tag_arr_ptr = ctx.module.add_global( let tag_len = tag.len();
int8.array_type(tag.len() as u32), let arr_ty = int8.array_type(tag_len as u32);
None, let tag_const = int8.const_array(
format!("tagptr{}", fun.1 .0).as_str(), &tag.iter().map(|&b| int8.const_int(b as u64, false)).collect::<Vec<_>>(),
); );
tag_arr_ptr.set_initializer(&int8.const_array( let arr_gv = ctx
&tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(), .module
)); .add_global(arr_ty, None, &format!("{}.arr", hash));
tag_arr_ptr.set_linkage(Linkage::Private); arr_gv.set_linkage(Linkage::Private);
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash); arr_gv.set_initializer(&tag_const);
tag_ptr.set_linkage(Linkage::Private);
tag_ptr.set_initializer(&ctx.ctx.const_struct( let st = ctx.ctx.const_struct(
&[ &[
tag_arr_ptr.as_pointer_value().const_cast(ptr_type).into(), arr_gv.as_pointer_value().const_cast(ptr_type).into(),
size_type.const_int(tag.len() as u64, false).into(), size_type.const_int(tag_len as u64, false).into(),
], ],
false, false,
)); );
tag_ptr let st_gv = ctx.module.add_global(st.get_type(), None, &hash);
}) st_gv.set_linkage(Linkage::Private);
.as_pointer_value(); st_gv.set_initializer(&st);
let arg_length = args.len() + usize::from(obj.is_some()); st_gv.as_pointer_value()
};
let n_params = fun.0.args.len();
let mut param_map: Vec<Option<ValueEnum<'ctx>>> = vec![None; n_params];
let mut pos_index = 0usize;
if let Some((obj_ty, obj_val)) = obj {
param_map[0] = Some(obj_val);
pos_index = 1;
}
for (maybe_key, val_enum) in args.drain(..) {
if let Some(kw_name) = maybe_key {
let param_pos = fun.0
.args
.iter()
.position(|arg| arg.name == kw_name)
.ok_or_else(|| format!("Unknown keyword argument '{}'", kw_name))?;
if param_map[param_pos].is_some() {
return Err(format!("Multiple values for argument '{}'", kw_name));
}
param_map[param_pos] = Some(val_enum);
} else {
while pos_index < n_params && param_map[pos_index].is_some() {
pos_index += 1;
}
if pos_index >= n_params {
return Err("Too many positional arguments given to function.".to_string());
}
param_map[pos_index] = Some(val_enum);
pos_index += 1;
}
}
for (i, param) in fun.0.args.iter().enumerate() {
if param_map[i].is_none() {
if let Some(default_expr) = &param.default_value {
let default_val = ctx.gen_symbol_val(generator, default_expr, param.ty).into();
param_map[i] = Some(default_val);
} else {
return Err(format!("Missing required argument '{}'", param.name));
}
}
}
let mut real_params = Vec::with_capacity(n_params);
for (i, param_spec) in fun.0.args.iter().enumerate() {
let some_valenum = param_map[i].take().unwrap();
let llvm_val = some_valenum.to_basic_value_enum(ctx, generator, param_spec.ty)?;
real_params.push((llvm_val, param_spec.ty));
}
let arg_count = real_params.len() as u64;
let stackptr = call_stacksave(ctx, Some("rpc.stack")); let stackptr = call_stacksave(ctx, Some("rpc.stack"));
let args_ptr = ctx
let i32_ty = ctx.ctx.i32_type();
let arg_array = ctx
.builder .builder
.build_array_alloca( .build_array_alloca(
ptr_type, ptr_type,
ctx.ctx.i32_type().const_int(arg_length as u64, false), i32_ty.const_int(arg_count, false),
"argptr", "rpc.arg_array",
) )
.unwrap(); .unwrap();
// -- rpc args handling for (i, (llvm_val, ty)) in real_params.iter().enumerate() {
let mut keys = fun.0.args.clone(); let arg_slot_ptr = unsafe {
let mut mapping = HashMap::new();
let mut is_keyword_arg = HashMap::new();
for (key, value) in args {
if let Some(key_name) = key {
mapping.insert(key_name, value);
is_keyword_arg.insert(key_name, true);
} else {
let arg_name = keys.remove(0).name;
mapping.insert(arg_name, value);
is_keyword_arg.insert(arg_name, false);
}
}
// default value handling
for k in keys {
mapping
.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
}
// reorder the parameters
let mut real_params = fun
.0
.args
.iter()
.map(|arg| {
mapping
.remove(&arg.name)
.unwrap()
.to_basic_value_enum(ctx, generator, arg.ty)
.map(|llvm_val| (llvm_val, arg.ty))
})
.collect::<Result<Vec<(_, _)>, _>>()?;
if let Some(obj) = obj {
if let ValueEnum::Static(obj_val) = obj.1 {
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
} else {
// should be an error here...
panic!("only host object is allowed");
}
}
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i));
let arg_ptr = unsafe {
ctx.builder.build_gep( ctx.builder.build_gep(
args_ptr, arg_array,
&[int32.const_int(i as u64, false)], &[
&format!("rpc.arg{i}"), i32_ty.const_int(i as u64, false),
],
&format!("rpc.arg_slot_{}", i),
) )
} }.unwrap();
.unwrap(); let arg_ptr = format_rpc_arg(generator, ctx, (*llvm_val, *ty, i));
ctx.builder.build_store(arg_ptr, arg_slot).unwrap(); ctx.builder.build_store(arg_slot_ptr, arg_ptr).unwrap();
} }
// Before calling rpc_send/rpc_send_async, add keyword arg info to tag
for arg in &fun.0.args {
if *is_keyword_arg.get(&arg.name).unwrap_or(&false) {
tag.push(b'k'); // Mark as keyword argument
}
gen_rpc_tag(ctx, arg.ty, true, &mut tag)?; // Pass true for is_kwarg
}
// call
if is_async { if is_async {
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| { let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
ctx.module.add_function( ctx.module.add_function(
"rpc_send_async", "rpc_send_async",
ctx.ctx.void_type().fn_type( ctx.ctx.void_type().fn_type(
&[ &[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()],
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
false, false,
), ),
None, None,
@ -942,45 +946,42 @@ fn rpc_codegen_callback_fn<'ctx>(
ctx.builder ctx.builder
.build_call( .build_call(
rpc_send_async, rpc_send_async,
&[service_id.into(), tag_ptr.into(), args_ptr.into()], &[
"rpc.send", service_id.into(),
tag_ptr.into(),
arg_array.into(),
],
"rpc.send_async",
) )
.unwrap(); .unwrap();
call_stackrestore(ctx, stackptr);
return Ok(None);
} else { } else {
let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| { let rpc_send = ctx.module.get_function("rpc_send").unwrap_or_else(|| {
ctx.module.add_function( ctx.module.add_function(
"rpc_send", "rpc_send",
ctx.ctx.void_type().fn_type( ctx.ctx.void_type().fn_type(
&[ &[int32.into(), tag_ptr_type.into(), ptr_type.ptr_type(AddressSpace::default()).into()],
int32.into(),
tag_ptr_type.ptr_type(AddressSpace::default()).into(),
ptr_type.ptr_type(AddressSpace::default()).into(),
],
false, false,
), ),
None, None,
) )
}); });
ctx.builder ctx.builder
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send") .build_call(
rpc_send,
&[
service_id.into(),
tag_ptr.into(),
arg_array.into(),
],
"rpc.send",
)
.unwrap(); .unwrap();
}
// reclaim stack space used by arguments
call_stackrestore(ctx, stackptr); call_stackrestore(ctx, stackptr);
if is_async { let maybe_ret = format_rpc_ret(generator, ctx, fun.0.ret);
// async RPCs do not return any values Ok(maybe_ret)
Ok(None)
} else {
let result = format_rpc_ret(generator, ctx, fun.0.ret);
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
// An RPC returning an NDArray would not touch here.
call_stackrestore(ctx, stackptr);
}
Ok(result)
} }
} }

View File

@ -3,7 +3,7 @@ use std::{
cmp::max, cmp::max,
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
convert::{From, TryInto}, convert::{From, TryInto},
iter::once, iter::{once, repeat_n},
sync::Arc, sync::Arc,
}; };
@ -187,7 +187,7 @@ fn fix_assignment_target_context(node: &mut ast::Located<ExprKind>) {
} }
} }
impl<'a> Fold<()> for Inferencer<'a> { impl Fold<()> for Inferencer<'_> {
type TargetU = Option<Type>; type TargetU = Option<Type>;
type Error = InferenceError; type Error = InferenceError;
@ -657,7 +657,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
type InferenceResult = Result<Type, InferenceError>; type InferenceResult = Result<Type, InferenceError>;
impl<'a> Inferencer<'a> { impl Inferencer<'_> {
/// Constrain a <: b /// Constrain a <: b
/// Currently implemented as unification /// Currently implemented as unification
fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> { fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> {
@ -1234,6 +1234,45 @@ impl<'a> Inferencer<'a> {
})); }));
} }
if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 {
let ndarray = self.fold_expr(args.remove(0))?;
let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap());
// Make a tuple of size `ndims` full of int32 (TODO: Make it usize)
let ret_ty = TypeEnum::TTuple {
ty: repeat_n(self.primitives.int32, ndims as usize).collect_vec(),
is_vararg_ctx: false,
};
let ret_ty = self.unifier.add_ty(ret_ty);
let func_ty = TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "a".into(),
default_value: None,
ty: ndarray.custom.unwrap(),
is_vararg: false,
}],
ret: ret_ty,
vars: VarMap::new(),
});
let func_ty = self.unifier.add_ty(func_ty);
return Ok(Some(Located {
location,
custom: Some(ret_ty),
node: ExprKind::Call {
func: Box::new(Located {
custom: Some(func_ty),
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![ndarray],
keywords: vec![],
},
}));
}
if id == &"np_dot".into() { if id == &"np_dot".into() {
let arg0 = self.fold_expr(args.remove(0))?; let arg0 = self.fold_expr(args.remove(0))?;
let arg1 = self.fold_expr(args.remove(0))?; let arg1 = self.fold_expr(args.remove(0))?;
@ -1555,7 +1594,7 @@ impl<'a> Inferencer<'a> {
})); }));
} }
// 2-argument ndarray n-dimensional factory functions // 2-argument ndarray n-dimensional factory functions
if id == &"np_reshape".into() && args.len() == 2 { if ["np_reshape".into(), "np_broadcast_to".into()].contains(id) && args.len() == 2 {
let arg0 = self.fold_expr(args.remove(0))?; let arg0 = self.fold_expr(args.remove(0))?;
let shape_expr = args.remove(0); let shape_expr = args.remove(0);
@ -1832,47 +1871,20 @@ impl<'a> Inferencer<'a> {
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) { if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) {
if sign.vars.is_empty() { if sign.vars.is_empty() {
// Build keyword argument map
let mut kwargs_map = HashMap::new();
for kw in &keywords {
if let Some(name) = &kw.node.arg {
// Check if keyword arg exists in function signature
if !sign.args.iter().any(|arg| arg.name == *name) {
return report_error(
&format!("Unexpected keyword argument '{}'", name),
kw.location,
);
}
kwargs_map.insert(*name, kw.node.value.custom.unwrap());
}
}
// Validate that all required args are provided
for arg in &sign.args {
if arg.default_value.is_none()
&& !kwargs_map.contains_key(&arg.name)
&& args.len() < sign.args.len()
{
return report_error(
&format!("Missing required argument '{}'", arg.name),
location,
);
}
}
let call = Call { let call = Call {
posargs: args.iter().map(|v| v.custom.unwrap()).collect(), posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
kwargs: kwargs_map, kwargs: keywords
.iter()
.map(|v| (*v.node.arg.as_ref().unwrap(), v.node.value.custom.unwrap()))
.collect(),
fun: RefCell::new(None), fun: RefCell::new(None),
ret: sign.ret, ret: sign.ret,
loc: Some(location), loc: Some(location),
operator_info: None, operator_info: None,
}; };
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| { self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]) HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
})?; })?;
return Ok(Located { return Ok(Located {
location, location,
custom: Some(sign.ret), custom: Some(sign.ret),
@ -1886,7 +1898,7 @@ impl<'a> Inferencer<'a> {
posargs: args.iter().map(|v| v.custom.unwrap()).collect(), posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
kwargs: keywords kwargs: keywords
.iter() .iter()
.filter_map(|v| v.node.arg.map(|name| (name, v.node.value.custom.unwrap()))) .map(|v| (*v.node.arg.as_ref().unwrap(), v.custom.unwrap()))
.collect(), .collect(),
fun: RefCell::new(None), fun: RefCell::new(None),
ret, ret,