Implement RPC for NDArrays #426
|
@ -23,6 +23,7 @@ use pyo3::{
|
|||
|
||||
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
||||
|
||||
use nac3core::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
collections::HashMap,
|
||||
|
@ -397,6 +398,27 @@ fn gen_rpc_tag(
|
|||
buffer.push(b'l');
|
||||
gen_rpc_tag(ctx, *ty, buffer)?;
|
||||
}
|
||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
let ndarray_ndims = if let TLiteral { values, .. } =
|
||||
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
||||
{
|
||||
if values.len() != 1 {
|
||||
return Err(format!("NDArray types with multiple literal bounds for ndims is not supported: {}", ctx.unifier.stringify(ty)));
|
||||
}
|
||||
|
||||
let value = values[0].clone();
|
||||
u64::try_from(value.clone())
|
||||
.map_err(|()| format!("Expected u64 for ndarray.ndims, got {value}"))?
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims));
|
||||
|
||||
buffer.push(b'a');
|
||||
buffer.push((ndarray_ndims & 0xFF) as u8);
|
||||
gen_rpc_tag(ctx, ndarray_dtype, buffer)?;
|
||||
}
|
||||
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue