Implement RPC for NDArrays #426
|
@ -23,6 +23,7 @@ use pyo3::{
|
|||
|
||||
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
||||
|
||||
use nac3core::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
collections::HashMap,
|
||||
|
@ -397,6 +398,27 @@ fn gen_rpc_tag(
|
|||
buffer.push(b'l');
|
||||
gen_rpc_tag(ctx, *ty, buffer)?;
|
||||
}
|
||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
let ndarray_ndims = if let TLiteral { values, .. } =
|
||||
&*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
||||
{
|
||||
if values.len() != 1 {
|
||||
return Err(format!("NDArray types with multiple literal bounds for ndims is not supported: {}", ctx.unifier.stringify(ty)));
|
||||
}
|
||||
|
||||
let value = values[0].clone();
|
||||
u64::try_from(value.clone())
|
||||
.map_err(|()| format!("Expected u64 for ndarray.ndims, got {value}"))?
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
assert!((0u64..=u64::from(u8::MAX)).contains(&ndarray_ndims));
|
||||
|
||||
buffer.push(b'a');
|
||||
buffer.push((ndarray_ndims & 0xFF) as u8);
|
||||
gen_rpc_tag(ctx, ndarray_dtype, buffer)?;
|
||||
}
|
||||
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -899,7 +899,6 @@ impl Nac3 {
|
|||
let builtins_mod = PyModule::import(py, "builtins").unwrap();
|
||||
let id_fn = builtins_mod.getattr("id").unwrap();
|
||||
let numpy_mod = PyModule::import(py, "numpy").unwrap();
|
||||
let numpy_typing_mod = PyModule::import(py, "numpy.typing").unwrap();
|
||||
let typing_mod = PyModule::import(py, "typing").unwrap();
|
||||
let types_mod = PyModule::import(py, "types").unwrap();
|
||||
|
||||
|
@ -930,7 +929,7 @@ impl Nac3 {
|
|||
float: get_attr_id(builtins_mod, "float"),
|
||||
float64: get_attr_id(numpy_mod, "float64"),
|
||||
list: get_attr_id(builtins_mod, "list"),
|
||||
ndarray: get_attr_id(numpy_typing_mod, "NDArray"),
|
||||
ndarray: get_attr_id(numpy_mod, "ndarray"),
|
||||
tuple: get_attr_id(builtins_mod, "tuple"),
|
||||
exception: get_attr_id(builtins_mod, "Exception"),
|
||||
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
|
||||
|
|
|
@ -1,6 +1,14 @@
|
|||
use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace};
|
||||
use inkwell::{
|
||||
types::{BasicType, BasicTypeEnum},
|
||||
values::BasicValueEnum,
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use nac3core::{
|
||||
codegen::{CodeGenContext, CodeGenerator},
|
||||
codegen::{
|
||||
classes::{NDArrayType, ProxyType},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
||||
toplevel::{
|
||||
helper::PrimDef,
|
||||
|
@ -485,7 +493,23 @@ impl InnerResolver {
|
|||
)));
|
||||
}
|
||||
|
||||
todo!()
|
||||
// npt.NDArray[T] == np.ndarray[Any, np.dtype[T]]
|
||||
let ndarray_dtype_pyty =
|
||||
self.helper.args_ty_fn.call1(py, (args.get_item(1)?,))?;
|
||||
let dtype = ndarray_dtype_pyty.downcast::<PyTuple>(py)?.get_item(0)?;
|
||||
|
||||
let ty = match self.get_pyty_obj_type(py, dtype, unifier, defs, primitives)? {
|
||||
Ok(ty) => ty,
|
||||
Err(err) => return Ok(Err(err)),
|
||||
};
|
||||
|
||||
if !unifier.is_concrete(ty.0, &[]) && !ty.1 {
|
||||
return Ok(Err(
|
||||
"type `ndarray` should take concrete parameters for dtype".into()
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Ok((make_ndarray_ty(unifier, primitives, Some(ty.0), None), true)))
|
||||
}
|
||||
TypeEnum::TTuple { .. } => {
|
||||
let args = match args
|
||||
|
@ -670,7 +694,7 @@ impl InnerResolver {
|
|||
}
|
||||
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
|
||||
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
|
||||
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
||||
let len: usize = obj.getattr("ndim")?.extract()?;
|
||||
if len == 0 {
|
||||
assert!(matches!(
|
||||
&*unifier.get_ty(ty),
|
||||
|
@ -679,10 +703,10 @@ impl InnerResolver {
|
|||
));
|
||||
Ok(Ok(extracted_ty))
|
||||
} else {
|
||||
let actual_ty =
|
||||
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
|
||||
match actual_ty {
|
||||
Ok(t) => match unifier.unify(ty, t) {
|
||||
let dtype = obj.getattr("dtype")?.getattr("type")?;
|
||||
let dtype_ty = self.get_pyty_obj_type(py, dtype, unifier, defs, primitives)?;
|
||||
match dtype_ty {
|
||||
Ok((t, _)) => match unifier.unify(ty, t) {
|
||||
Ok(()) => {
|
||||
let ndarray_ty =
|
||||
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims));
|
||||
|
@ -966,7 +990,143 @@ impl InnerResolver {
|
|||
|
||||
Ok(Some(global.as_pointer_value().into()))
|
||||
} else if ty_id == self.primitive_ids.ndarray {
|
||||
todo!()
|
||||
let id_str = id.to_string();
|
||||
|
||||
if let Some(global) = ctx.module.get_global(&id_str) {
|
||||
return Ok(Some(global.as_pointer_value().into()));
|
||||
}
|
||||
|
||||
let ndarray_ty = if matches!(&*ctx.unifier.get_ty_immutable(expected_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id())
|
||||
{
|
||||
expected_ty
|
||||
} else {
|
||||
unreachable!("must be ndarray")
|
||||
};
|
||||
let (ndarray_dtype, ndarray_ndims) =
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
|
||||
let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
|
||||
|
||||
{
|
||||
if self.global_value_ids.read().contains_key(&id) {
|
||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
||||
ctx.module.add_global(
|
||||
ndarray_llvm_ty.as_underlying_type(),
|
||||
Some(AddressSpace::default()),
|
||||
&id_str,
|
||||
)
|
||||
});
|
||||
return Ok(Some(global.as_pointer_value().into()));
|
||||
}
|
||||
self.global_value_ids.write().insert(id, obj.into());
|
||||
}
|
||||
|
||||
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
||||
else {
|
||||
unreachable!("Expected Literal for ndarray_ndims")
|
||||
};
|
||||
|
||||
let ndarray_ndims = if values.len() == 1 {
|
||||
values[0].clone()
|
||||
} else {
|
||||
todo!("Unpacking literal of more than one element unimplemented")
|
||||
};
|
||||
let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else {
|
||||
unreachable!("Expected u64 value for ndarray_ndims")
|
||||
};
|
||||
|
||||
// Obtain the shape of the ndarray
|
||||
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
|
||||
assert_eq!(shape_tuple.len(), ndarray_ndims as usize);
|
||||
let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, elem)| {
|
||||
self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err(
|
||||
|e| super::CompileError::new_err(format!("Error getting element {i}: {e}")),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let shape_values = shape_values?.unwrap();
|
||||
let shape_values = llvm_usize.const_array(
|
||||
&shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(),
|
||||
);
|
||||
|
||||
// create a global for ndarray.shape and initialize it using the shape
|
||||
let shape_global = ctx.module.add_global(
|
||||
llvm_usize.array_type(ndarray_ndims as u32),
|
||||
Some(AddressSpace::default()),
|
||||
&(id_str.clone() + ".shape"),
|
||||
);
|
||||
shape_global.set_initializer(&shape_values);
|
||||
|
||||
// Obtain the (flattened) elements of the ndarray
|
||||
let sz: usize = obj.getattr("size")?.extract()?;
|
||||
let data: Result<Option<Vec<_>>, _> = (0..sz)
|
||||
.map(|i| {
|
||||
obj.getattr("flat")?.get_item(i).and_then(|elem| {
|
||||
self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| {
|
||||
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
|
||||
})
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let data = data?.unwrap().into_iter();
|
||||
let data = match ndarray_dtype_llvm_ty {
|
||||
BasicTypeEnum::ArrayType(ty) => {
|
||||
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
|
||||
}
|
||||
|
||||
BasicTypeEnum::FloatType(ty) => {
|
||||
ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec())
|
||||
}
|
||||
|
||||
BasicTypeEnum::IntType(ty) => {
|
||||
ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec())
|
||||
}
|
||||
|
||||
BasicTypeEnum::PointerType(ty) => {
|
||||
ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec())
|
||||
}
|
||||
|
||||
BasicTypeEnum::StructType(ty) => {
|
||||
ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec())
|
||||
}
|
||||
|
||||
BasicTypeEnum::VectorType(_) => unreachable!(),
|
||||
};
|
||||
|
||||
// create a global for ndarray.data and initialize it using the elements
|
||||
let data_global = ctx.module.add_global(
|
||||
ndarray_dtype_llvm_ty.array_type(sz as u32),
|
||||
Some(AddressSpace::default()),
|
||||
&(id_str.clone() + ".data"),
|
||||
);
|
||||
data_global.set_initializer(&data);
|
||||
|
||||
// create a global for the ndarray object and initialize it
|
||||
let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[
|
||||
llvm_usize.const_int(ndarray_ndims, false).into(),
|
||||
shape_global
|
||||
.as_pointer_value()
|
||||
.const_cast(llvm_usize.ptr_type(AddressSpace::default()))
|
||||
.into(),
|
||||
data_global
|
||||
.as_pointer_value()
|
||||
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
|
||||
.into(),
|
||||
]);
|
||||
|
||||
let ndarray = ctx.module.add_global(
|
||||
ndarray_llvm_ty.as_underlying_type(),
|
||||
Some(AddressSpace::default()),
|
||||
&id_str,
|
||||
);
|
||||
ndarray.set_initializer(&value);
|
||||
|
||||
Ok(Some(ndarray.as_pointer_value().into()))
|
||||
} else if ty_id == self.primitive_ids.tuple {
|
||||
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
|
||||
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
*.bc
|
||||
*.ll
|
||||
*.o
|
||||
/demo
|
||||
|
|
Loading…
Reference in New Issue