Implement RPC for NDArrays #426
|
@ -23,6 +23,7 @@ use pyo3::{
|
||||||
|
|
||||||
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
||||||
|
|
||||||
|
use nac3core::toplevel::numpy::unpack_ndarray_var_tys;
|
||||||
use std::{
|
use std::{
|
||||||
collections::hash_map::DefaultHasher,
|
collections::hash_map::DefaultHasher,
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
|
@ -397,6 +398,27 @@ fn gen_rpc_tag(
|
||||||
buffer.push(b'l');
|
buffer.push(b'l');
|
||||||
gen_rpc_tag(ctx, *ty, buffer)?;
|
gen_rpc_tag(ctx, *ty, buffer)?;
|
||||||
}
|
}
|
||||||
|
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
|
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||||
|
let ndarray_ndims = if let TLiteral { values, .. } =
|
||||||
|
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
||||||
|
{
|
||||||
|
if values.len() != 1 {
|
||||||
|
return Err(format!("NDArray types with multiple literal bounds for ndims is not supported: {}", ctx.unifier.stringify(ty)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let value = values[0].clone();
|
||||||
|
u64::try_from(value.clone())
|
||||||
|
.map_err(|()| format!("Expected u64 for ndarray.ndims, got {value}"))?
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims));
|
||||||
|
|
||||||
|
buffer.push(b'a');
|
||||||
|
buffer.push((ndarray_ndims & 0xFF) as u8);
|
||||||
|
gen_rpc_tag(ctx, ndarray_dtype, buffer)?;
|
||||||
|
}
|
||||||
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue