diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index b617eea..7e64dc1 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -23,6 +23,7 @@ use pyo3::{ use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; +use nac3core::toplevel::numpy::unpack_ndarray_var_tys; use std::{ collections::hash_map::DefaultHasher, collections::HashMap, @@ -397,6 +398,27 @@ fn gen_rpc_tag( buffer.push(b'l'); gen_rpc_tag(ctx, *ty, buffer)?; } + TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); + let ndarray_ndims = if let TLiteral { values, .. } = + &*ctx.unifier.get_ty_immutable(ndarray_ndims) + { + if values.len() != 1 { + return Err(format!("NDArray types with multiple literal bounds for ndims is not supported: {}", ctx.unifier.stringify(ty))); + } + + let value = values[0].clone(); + u64::try_from(value.clone()) + .map_err(|()| format!("Expected u64 for ndarray.ndims, got {value}"))? + } else { + unreachable!() + }; + assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims)); + + buffer.push(b'a'); + buffer.push((ndarray_ndims & 0xFF) as u8); + gen_rpc_tag(ctx, ndarray_dtype, buffer)?; + } _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), } }