forked from M-Labs/nac3
Compare commits
32 Commits
b0526ba29f
...
6a0377642f
Author | SHA1 | Date | |
---|---|---|---|
|
6a0377642f | ||
ff10747f19 | |||
|
492645c0cc | ||
b5464831b4 | |||
669f58cc44 | |||
2a78302fa9 | |||
4b3cf9fe46 | |||
a3d7fdd83d | |||
d7a254c758 | |||
776cf8c135 | |||
bf3e6b4096 | |||
ca4cf9f443 | |||
32637937f2 | |||
a6b35de35b | |||
b5f13235c7 | |||
|
659ef7ee90 | ||
12f0717308 | |||
|
78bdfaac50 | ||
|
0784d2f4a9 | ||
|
a5c550b86b | ||
c598d57024 | |||
ad17bf7914 | |||
b21562cfd8 | |||
569c3f1568 | |||
945818aea0 | |||
1cf40455fc | |||
d0efc1419b | |||
521f742916 | |||
cabb20e7f9 | |||
f31be04f4d | |||
5fdc9c8f69 | |||
b009a66ab9 |
@ -4,6 +4,8 @@ use std::{
|
||||
iter::once,
|
||||
mem,
|
||||
sync::Arc,
|
||||
io,
|
||||
io::*,
|
||||
};
|
||||
|
||||
use itertools::Itertools;
|
||||
@ -516,8 +518,16 @@ fn format_rpc_arg<'ctx>(
|
||||
ctx.builder.build_int_truncate_or_bit_cast(sizeof_pdata, llvm_usize, "").unwrap();
|
||||
|
||||
let sizeof_buf_shape = ctx.builder.build_int_mul(sizeof_usize, ndims, "").unwrap();
|
||||
let sizeof_buf = ctx.builder.build_int_add(sizeof_buf_shape, sizeof_pdata, "").unwrap();
|
||||
|
||||
let alignment = llvm_usize.const_int(8, false);
|
||||
let sizeof_buf = ctx
|
||||
.builder
|
||||
.build_int_add(
|
||||
sizeof_buf_shape,
|
||||
ctx.builder.build_int_add(sizeof_pdata, alignment, "").unwrap(),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
// buf = { data: void*, shape: [size_t; ndims]; }
|
||||
let buf = ctx.builder.build_array_alloca(llvm_i8, sizeof_buf, "rpc.arg").unwrap();
|
||||
let buf = ArraySliceValue::from_ptr_val(buf, sizeof_buf, Some("rpc.arg"));
|
||||
@ -528,12 +538,46 @@ fn format_rpc_arg<'ctx>(
|
||||
// Write to `buf->data`
|
||||
let carray_data = carray.load_data(ctx);
|
||||
let carray_data = ctx.builder.build_pointer_cast(carray_data, llvm_pi8, "").unwrap();
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let dst_size = sizeof_pdata;
|
||||
let src_size = sizeof_pdata;
|
||||
let cmp = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::ULE, src_size, dst_size, "buffer_size_check1")
|
||||
.unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
cmp,
|
||||
"0:AssertionError",
|
||||
"Buffer overflow risk in RPC data copy: source size {0} exceeds destination size {1}",
|
||||
[Some(src_size), Some(dst_size), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
call_memcpy(ctx, buf_data, carray_data, sizeof_pdata);
|
||||
|
||||
// Write to `buf->shape`
|
||||
let carray_shape = ndarray.shape().base_ptr(ctx, generator);
|
||||
let carray_shape_i8 =
|
||||
ctx.builder.build_pointer_cast(carray_shape, llvm_pi8, "").unwrap();
|
||||
// Safety check for buffer overflow
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
let dst_size = sizeof_buf_shape;
|
||||
let src_size = sizeof_buf_shape;
|
||||
let cmp = ctx
|
||||
.builder
|
||||
.build_int_compare(IntPredicate::ULE, src_size, dst_size, "buffer_size_check2")
|
||||
.unwrap();
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
cmp,
|
||||
"0:AssertionError",
|
||||
"Buffer overflow risk in RPC shape copy: source size {0} exceeds destination size {1}",
|
||||
[Some(src_size), Some(dst_size), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
call_memcpy(ctx, buf_shape, carray_shape_i8, sizeof_buf_shape);
|
||||
|
||||
buf.base_ptr(ctx, generator)
|
||||
@ -562,6 +606,7 @@ fn format_rpc_ret<'ctx>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ret_ty: Type,
|
||||
is_async: bool,
|
||||
) -> Option<BasicValueEnum<'ctx>> {
|
||||
// -- receive value:
|
||||
// T result = {
|
||||
@ -589,6 +634,7 @@ fn format_rpc_ret<'ctx>(
|
||||
return None;
|
||||
}
|
||||
|
||||
let stackptr = call_stacksave(ctx, Some("rpc.stack.ret"));
|
||||
let prehead_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let current_function = prehead_bb.get_parent().unwrap();
|
||||
let head_bb = ctx.ctx.append_basic_block(current_function, "rpc.head");
|
||||
@ -659,14 +705,9 @@ fn format_rpc_ret<'ctx>(
|
||||
let unaligned_buffer_size =
|
||||
ctx.builder.build_int_add(sizeof_ptr, sizeof_shape, "").unwrap();
|
||||
|
||||
let stackptr = call_stacksave(ctx, None);
|
||||
let buffer = type_aligned_alloca(
|
||||
generator,
|
||||
ctx,
|
||||
llvm_i8_8,
|
||||
unaligned_buffer_size,
|
||||
Some("rpc.buffer"),
|
||||
);
|
||||
let stackptr = call_stacksave(ctx, Some("rpc.stack.ret"));
|
||||
let buffer =
|
||||
type_aligned_alloca(generator, ctx, llvm_i8_8, sizeof_ptr, Some("rpc.buffer"));
|
||||
let buffer = ArraySliceValue::from_ptr_val(buffer, unaligned_buffer_size, None);
|
||||
|
||||
// The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
|
||||
@ -817,6 +858,9 @@ fn format_rpc_ret<'ctx>(
|
||||
}
|
||||
};
|
||||
|
||||
if !is_async && !result.get_type().is_pointer_type() {
|
||||
call_stackrestore(ctx, stackptr);
|
||||
}
|
||||
Some(result)
|
||||
}
|
||||
|
||||
@ -835,32 +879,51 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
||||
|
||||
let service_id = int32.const_int(fun.1 .0 as u64, false);
|
||||
// -- setup rpc tags
|
||||
// build the RPC tag with keyword
|
||||
let mut tag = Vec::new();
|
||||
|
||||
if obj.is_some() {
|
||||
tag.push(b'O');
|
||||
}
|
||||
for arg in &fun.0.args {
|
||||
gen_rpc_tag(ctx, arg.ty, &mut tag)?;
|
||||
}
|
||||
tag.push(b'|');
|
||||
for arg in &fun.0.args {
|
||||
let name_string = arg.name.to_string();
|
||||
let name_bytes = name_string.as_bytes();
|
||||
if name_bytes.len() > 255 {
|
||||
return Err(format!("Parameter name too long: '{}'", arg.name));
|
||||
}
|
||||
tag.push(name_bytes.len() as u8);
|
||||
tag.extend_from_slice(name_bytes);
|
||||
}
|
||||
tag.push(b':');
|
||||
gen_rpc_tag(ctx, fun.0.ret, &mut tag)?;
|
||||
|
||||
let marker = b'K';
|
||||
if obj.is_some() {
|
||||
tag.insert(1, marker);
|
||||
} else {
|
||||
tag.insert(0, marker);
|
||||
}
|
||||
println!("Constructed RPC tag: {tag:?}");
|
||||
io::stdout().flush().unwrap();
|
||||
|
||||
let mut hasher = DefaultHasher::new();
|
||||
tag.hash(&mut hasher);
|
||||
let hash = format!("{}", hasher.finish());
|
||||
|
||||
let tag_ptr = ctx
|
||||
.module
|
||||
.get_global(hash.as_str())
|
||||
.get_global(&hash)
|
||||
.unwrap_or_else(|| {
|
||||
let tag_arr_ptr = ctx.module.add_global(
|
||||
int8.array_type(tag.len() as u32),
|
||||
None,
|
||||
format!("tagptr{}", fun.1 .0).as_str(),
|
||||
&format!("tag_array_{hash}"),
|
||||
);
|
||||
tag_arr_ptr.set_initializer(&int8.const_array(
|
||||
&tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(),
|
||||
&tag.iter().map(|&b| int8.const_int(u64::from(b), false)).collect::<Vec<_>>(),
|
||||
));
|
||||
tag_arr_ptr.set_linkage(Linkage::Private);
|
||||
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
|
||||
@ -877,7 +940,6 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
.as_pointer_value();
|
||||
|
||||
let arg_length = args.len() + usize::from(obj.is_some());
|
||||
|
||||
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
|
||||
let args_ptr = ctx
|
||||
.builder
|
||||
@ -888,39 +950,53 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// -- rpc args handling
|
||||
let mut keys = fun.0.args.clone();
|
||||
let mut mapping = HashMap::new();
|
||||
for (key, value) in args {
|
||||
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
||||
for (maybe_key, value) in args {
|
||||
let key_str = if let Some(k) = maybe_key {
|
||||
let s = k.to_string();
|
||||
keys.retain(|p| p.name.to_string() != s);
|
||||
s
|
||||
} else {
|
||||
let removed = keys.remove(0).name.to_string();
|
||||
removed
|
||||
};
|
||||
mapping.insert(key_str, value);
|
||||
}
|
||||
// default value handling
|
||||
for k in keys {
|
||||
mapping
|
||||
.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
|
||||
let key_str = k.name.to_string();
|
||||
if let Some(default_val) = k.default_value.as_ref() {
|
||||
mapping.insert(key_str, ctx.gen_symbol_val(generator, default_val, k.ty).into());
|
||||
} else {
|
||||
return Err(format!(
|
||||
"No argument provided for parameter '{}' and no default value exists",
|
||||
k.name
|
||||
));
|
||||
}
|
||||
}
|
||||
let mut real_params = Vec::new();
|
||||
for arg in &fun.0.args {
|
||||
let key_str = arg.name.to_string();
|
||||
let value = if let Some(val) = mapping.remove(&key_str) {
|
||||
val
|
||||
} else if let Some(default_val) = arg.default_value.as_ref() {
|
||||
ctx.gen_symbol_val(generator, default_val, arg.ty).into()
|
||||
} else {
|
||||
return Err(format!(
|
||||
"No argument provided for parameter '{}' and no default value exists",
|
||||
arg.name
|
||||
));
|
||||
};
|
||||
let llvm_val = value.to_basic_value_enum(ctx, generator, arg.ty)?;
|
||||
real_params.push((llvm_val, arg.ty));
|
||||
}
|
||||
// reorder the parameters
|
||||
let mut real_params = fun
|
||||
.0
|
||||
.args
|
||||
.iter()
|
||||
.map(|arg| {
|
||||
mapping
|
||||
.remove(&arg.name)
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, arg.ty)
|
||||
.map(|llvm_val| (llvm_val, arg.ty))
|
||||
})
|
||||
.collect::<Result<Vec<(_, _)>, _>>()?;
|
||||
if let Some(obj) = obj {
|
||||
if let ValueEnum::Static(obj_val) = obj.1 {
|
||||
real_params.insert(0, (obj_val.get_const_obj(ctx, generator), obj.0));
|
||||
} else {
|
||||
// should be an error here...
|
||||
panic!("only host object is allowed");
|
||||
return Err("Only host objects are allowed for 'self'".into());
|
||||
}
|
||||
}
|
||||
|
||||
for (i, (arg, arg_ty)) in real_params.iter().enumerate() {
|
||||
let arg_slot = format_rpc_arg(generator, ctx, (*arg, *arg_ty, i));
|
||||
let arg_ptr = unsafe {
|
||||
@ -933,8 +1009,6 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
.unwrap();
|
||||
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
|
||||
}
|
||||
|
||||
// call
|
||||
infer_and_call_function(
|
||||
ctx,
|
||||
if is_async { "rpc_send_async" } else { "rpc_send" },
|
||||
@ -944,20 +1018,15 @@ fn rpc_codegen_callback_fn<'ctx>(
|
||||
None,
|
||||
);
|
||||
|
||||
// reclaim stack space used by arguments
|
||||
call_stackrestore(ctx, stackptr);
|
||||
|
||||
if is_async {
|
||||
// async RPCs do not return any values
|
||||
Ok(None)
|
||||
} else {
|
||||
let result = format_rpc_ret(generator, ctx, fun.0.ret);
|
||||
|
||||
let result = format_rpc_ret(generator, ctx, fun.0.ret, is_async);
|
||||
if !result.is_some_and(|res| res.get_type().is_pointer_type()) {
|
||||
// An RPC returning an NDArray would not touch here.
|
||||
call_stackrestore(ctx, stackptr);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user