Implement RPC for NDArrays #426

Merged
sb10q merged 5 commits from enhance/issue-149-ndarray/rpc into master 2024-06-19 22:25:50 +08:00
4 changed files with 193 additions and 11 deletions

View File

@ -23,6 +23,7 @@ use pyo3::{
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
use nac3core::toplevel::numpy::unpack_ndarray_var_tys;
use std::{
collections::hash_map::DefaultHasher,
collections::HashMap,
@ -397,6 +398,27 @@ fn gen_rpc_tag(
buffer.push(b'l');
gen_rpc_tag(ctx, *ty, buffer)?;
}
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let ndarray_ndims = if let TLiteral { values, .. } =
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
{
if values.len() != 1 {
return Err(format!("NDArray types with multiple literal bounds for ndims is not supported: {}", ctx.unifier.stringify(ty)));
}
let value = values[0].clone();
u64::try_from(value.clone())
.map_err(|()| format!("Expected u64 for ndarray.ndims, got {value}"))?
} else {
unreachable!()
};
assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims));
buffer.push(b'a');
buffer.push((ndarray_ndims & 0xFF) as u8);
gen_rpc_tag(ctx, ndarray_dtype, buffer)?;
}
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
}
}

View File

@ -899,7 +899,6 @@ impl Nac3 {
let builtins_mod = PyModule::import(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap();
let numpy_mod = PyModule::import(py, "numpy").unwrap();
let numpy_typing_mod = PyModule::import(py, "numpy.typing").unwrap();
let typing_mod = PyModule::import(py, "typing").unwrap();
let types_mod = PyModule::import(py, "types").unwrap();
@ -930,7 +929,7 @@ impl Nac3 {
float: get_attr_id(builtins_mod, "float"),
float64: get_attr_id(numpy_mod, "float64"),
list: get_attr_id(builtins_mod, "list"),
ndarray: get_attr_id(numpy_typing_mod, "NDArray"),
ndarray: get_attr_id(numpy_mod, "ndarray"),
tuple: get_attr_id(builtins_mod, "tuple"),
exception: get_attr_id(builtins_mod, "Exception"),
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),

View File

@ -1,6 +1,14 @@
use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace};
use inkwell::{
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
AddressSpace,
};
use itertools::Itertools;
use nac3core::{
codegen::{CodeGenContext, CodeGenerator},
codegen::{
classes::{NDArrayType, ProxyType},
CodeGenContext, CodeGenerator,
},
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{
helper::PrimDef,
@ -485,7 +493,23 @@ impl InnerResolver {
)));
}
todo!()
// npt.NDArray[T] == np.ndarray[Any, np.dtype[T]]
let ndarray_dtype_pyty =
self.helper.args_ty_fn.call1(py, (args.get_item(1)?,))?;
let dtype = ndarray_dtype_pyty.downcast::<PyTuple>(py)?.get_item(0)?;
let ty = match self.get_pyty_obj_type(py, dtype, unifier, defs, primitives)? {
Ok(ty) => ty,
Err(err) => return Ok(Err(err)),
};
if !unifier.is_concrete(ty.0, &[]) && !ty.1 {
return Ok(Err(
"type `ndarray` should take concrete parameters for dtype".into()
));
}
Ok(Ok((make_ndarray_ty(unifier, primitives, Some(ty.0), None), true)))
}
TypeEnum::TTuple { .. } => {
let args = match args
@ -670,7 +694,7 @@ impl InnerResolver {
}
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
let len: usize = obj.getattr("ndim")?.extract()?;
if len == 0 {
assert!(matches!(
&*unifier.get_ty(ty),
@ -679,10 +703,10 @@ impl InnerResolver {
));
Ok(Ok(extracted_ty))
} else {
let actual_ty =
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
match actual_ty {
Ok(t) => match unifier.unify(ty, t) {
let dtype = obj.getattr("dtype")?.getattr("type")?;
let dtype_ty = self.get_pyty_obj_type(py, dtype, unifier, defs, primitives)?;
match dtype_ty {
Ok((t, _)) => match unifier.unify(ty, t) {
Ok(()) => {
let ndarray_ty =
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims));
@ -966,7 +990,143 @@ impl InnerResolver {
Ok(Some(global.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.ndarray {
todo!()
let id_str = id.to_string();
if let Some(global) = ctx.module.get_global(&id_str) {
return Ok(Some(global.as_pointer_value().into()));
}
let ndarray_ty = if matches!(&*ctx.unifier.get_ty_immutable(expected_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id())
{
expected_ty
} else {
unreachable!("must be ndarray")
};
let (ndarray_dtype, ndarray_ndims) =
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
{
if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global(
ndarray_llvm_ty.as_underlying_type(),
Some(AddressSpace::default()),
&id_str,
)
});
return Ok(Some(global.as_pointer_value().into()));
}
self.global_value_ids.write().insert(id, obj.into());
}
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndarray_ndims)
else {
unreachable!("Expected Literal for ndarray_ndims")
};
let ndarray_ndims = if values.len() == 1 {
values[0].clone()
} else {
todo!("Unpacking literal of more than one element unimplemented")
};
let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else {
unreachable!("Expected u64 value for ndarray_ndims")
};
// Obtain the shape of the ndarray
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
assert_eq!(shape_tuple.len(), ndarray_ndims as usize);
let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
.iter()
.enumerate()
.map(|(i, elem)| {
self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err(
|e| super::CompileError::new_err(format!("Error getting element {i}: {e}")),
)
})
.collect();
let shape_values = shape_values?.unwrap();
let shape_values = llvm_usize.const_array(
&shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(),
);
// create a global for ndarray.shape and initialize it using the shape
let shape_global = ctx.module.add_global(
llvm_usize.array_type(ndarray_ndims as u32),
Some(AddressSpace::default()),
&(id_str.clone() + ".shape"),
);
shape_global.set_initializer(&shape_values);
// Obtain the (flattened) elements of the ndarray
let sz: usize = obj.getattr("size")?.extract()?;
let data: Result<Option<Vec<_>>, _> = (0..sz)
.map(|i| {
obj.getattr("flat")?.get_item(i).and_then(|elem| {
self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| {
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
})
})
})
.collect();
let data = data?.unwrap().into_iter();
let data = match ndarray_dtype_llvm_ty {
BasicTypeEnum::ArrayType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
}
BasicTypeEnum::FloatType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec())
}
BasicTypeEnum::IntType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec())
}
BasicTypeEnum::PointerType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec())
}
BasicTypeEnum::StructType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec())
}
BasicTypeEnum::VectorType(_) => unreachable!(),
};
// create a global for ndarray.data and initialize it using the elements
let data_global = ctx.module.add_global(
ndarray_dtype_llvm_ty.array_type(sz as u32),
Some(AddressSpace::default()),
&(id_str.clone() + ".data"),
);
data_global.set_initializer(&data);
// create a global for the ndarray object and initialize it
let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[
llvm_usize.const_int(ndarray_ndims, false).into(),
shape_global
.as_pointer_value()
.const_cast(llvm_usize.ptr_type(AddressSpace::default()))
.into(),
data_global
.as_pointer_value()
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
.into(),
]);
let ndarray = ctx.module.add_global(
ndarray_llvm_ty.as_underlying_type(),
Some(AddressSpace::default()),
&id_str,
);
ndarray.set_initializer(&value);
Ok(Some(ndarray.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };

View File

@ -1,3 +1,4 @@
*.bc
*.ll
*.o
/demo