From d3cd2a8d99ece0837b24cb2358a42f140429e09c Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 17 Jun 2024 15:01:22 +0800 Subject: [PATCH] artiq: Add support for generating RPC tag for ndarray --- nac3artiq/src/codegen.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index b617eeae..7e64dc1d 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -23,6 +23,7 @@ use pyo3::{ use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; +use nac3core::toplevel::numpy::unpack_ndarray_var_tys; use std::{ collections::hash_map::DefaultHasher, collections::HashMap, @@ -397,6 +398,27 @@ fn gen_rpc_tag( buffer.push(b'l'); gen_rpc_tag(ctx, *ty, buffer)?; } + TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); + let ndarray_ndims = if let TLiteral { values, .. } = + &*ctx.unifier.get_ty_immutable(ndarray_ndims) + { + if values.len() != 1 { + return Err(format!("NDArray types with multiple literal bounds for ndims is not supported: {}", ctx.unifier.stringify(ty))); + } + + let value = values[0].clone(); + u64::try_from(value.clone()) + .map_err(|()| format!("Expected u64 for ndarray.ndims, got {value}"))? + } else { + unreachable!() + }; + assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims)); + + buffer.push(b'a'); + buffer.push((ndarray_ndims & 0xFF) as u8); + gen_rpc_tag(ctx, ndarray_dtype, buffer)?; + } _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), } }