artiq: WIP - Implement Python-to-LLVM conversion of ndarray

This commit is contained in:
David Mak 2024-06-14 14:48:29 +08:00
parent 676412fe6d
commit 4cffd3aa07
2 changed files with 175 additions and 4 deletions

View File

@ -64,7 +64,9 @@ use tempfile::{self, TempDir};
use crate::codegen::attributes_writeback;
use crate::{
codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver},
symbol_resolver::{
DeferredEvaluationStore, InnerResolver, NumpyHelper, PythonHelper, Resolver,
},
};
mod codegen;
@ -329,6 +331,11 @@ impl Nac3 {
type_fn: builtins.getattr("type").unwrap().to_object(py),
origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py),
args_ty_fn: typings.getattr("get_args").unwrap().to_object(py),
np_helpers: NumpyHelper {
ndarray_shape: |obj| obj.getattr("shape").unwrap(),
ndarray_size_fn: |obj| obj.getattr("size").unwrap(),
ndarray_flat_fn: |obj| obj.getattr("flat").unwrap(),
},
store_obj: store_obj.clone(),
store_str,
};

View File

@ -1,6 +1,14 @@
use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace};
use inkwell::{
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
AddressSpace,
};
use itertools::Itertools;
use nac3core::{
codegen::{CodeGenContext, CodeGenerator},
codegen::{
classes::{NDArrayType, ProxyType},
CodeGenContext, CodeGenerator,
},
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{
helper::PrimDef,
@ -85,6 +93,20 @@ pub struct InnerResolver {
pub struct Resolver(pub Arc<InnerResolver>);
/// Helpers for invoking NumPy functions.
#[allow(clippy::struct_field_names)]
#[derive(Clone)]
pub struct NumpyHelper {
/// [`numpy.ndarray.shape`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html)
pub ndarray_shape: fn(&PyAny) -> &PyAny,
/// [`numpy.ndarray.size`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.size.html)
pub ndarray_size_fn: fn(&PyAny) -> &PyAny,
/// [`numpy.ndarray.flat`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flat.html)
pub ndarray_flat_fn: fn(&PyAny) -> &PyAny,
}
#[derive(Clone)]
pub struct PythonHelper {
pub type_fn: PyObject,
@ -92,6 +114,10 @@ pub struct PythonHelper {
pub id_fn: PyObject,
pub origin_ty_fn: PyObject,
pub args_ty_fn: PyObject,
/// See [`NumpyHelper`].
pub np_helpers: NumpyHelper,
pub store_obj: PyObject,
pub store_str: PyObject,
}
@ -958,7 +984,145 @@ impl InnerResolver {
Ok(Some(global.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.ndarray {
todo!()
let id_str = id.to_string();
if let Some(global) = ctx.module.get_global(&id_str) {
return Ok(Some(global.as_pointer_value().into()));
}
let ndarray_ty = if matches!(&*ctx.unifier.get_ty_immutable(expected_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id())
{
expected_ty
} else {
unreachable!("must be ndarray")
};
let (ndarray_dtype, ndarray_ndims) =
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
let llvm_usize = generator.get_size_type(ctx.ctx);
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
{
if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global(
ndarray_llvm_ty.as_underlying_type(),
Some(AddressSpace::default()),
&id_str,
)
});
return Ok(Some(global.as_pointer_value().into()));
}
self.global_value_ids.write().insert(id, obj.into());
}
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndarray_ndims)
else {
unreachable!("Expected Literal for ndarray_ndims")
};
let ndarray_ndims = if values.len() == 1 {
values[0].clone()
} else {
todo!("Unpacking literal of more than one element unimplemented")
};
let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else {
unreachable!("Expected u64 value for ndarray_ndims")
};
// Obtain the shape of the ndarray
let shape_tuple = (self.helper.np_helpers.ndarray_shape)(obj);
let shape_tuple = shape_tuple.downcast::<PyTuple>()?;
assert_eq!(shape_tuple.len(), ndarray_ndims as usize);
let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
.iter()
.enumerate()
.map(|(i, elem)| {
self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err(
|e| super::CompileError::new_err(format!("Error getting element {i}: {e}")),
)
})
.collect();
let shape_values = shape_values?.unwrap();
let shape_values = llvm_usize.const_array(
&shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(),
);
// create a global for ndarray.shape and initialize it using the shape
let shape_global = ctx.module.add_global(
llvm_usize.array_type(ndarray_ndims as u32),
Some(AddressSpace::default()),
&(id_str.clone() + ".shape"),
);
shape_global.set_initializer(&shape_values);
// Obtain the (flattened) elements of the ndarray
let sz = (self.helper.np_helpers.ndarray_size_fn)(obj);
let sz = sz.extract::<usize>()?;
let data: Result<Option<Vec<_>>, _> = (0..sz)
.map(|i| {
(self.helper.np_helpers.ndarray_flat_fn)(obj).get_item(i).and_then(|elem| {
self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| {
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
})
})
})
.collect();
let data = data?.unwrap().into_iter();
let data = match ndarray_dtype_llvm_ty {
BasicTypeEnum::ArrayType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
}
BasicTypeEnum::FloatType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec())
}
BasicTypeEnum::IntType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec())
}
BasicTypeEnum::PointerType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec())
}
BasicTypeEnum::StructType(ty) => {
ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec())
}
BasicTypeEnum::VectorType(_) => unreachable!(),
};
// create a global for ndarray.data and initialize it using the elements
let data_global = ctx.module.add_global(
ndarray_dtype_llvm_ty.array_type(sz as u32),
Some(AddressSpace::default()),
&(id_str.clone() + ".data"),
);
data_global.set_initializer(&data);
// create a global for the ndarray object and initialize it
let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[
llvm_usize.const_int(ndarray_ndims, false).into(),
shape_global
.as_pointer_value()
.const_cast(llvm_usize.ptr_type(AddressSpace::default()))
.into(),
data_global
.as_pointer_value()
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
.into(),
]);
let ndarray = ctx.module.add_global(
ndarray_llvm_ty.as_underlying_type(),
Some(AddressSpace::default()),
&id_str,
);
ndarray.set_initializer(&value);
Ok(Some(ndarray.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };