forked from M-Labs/nac3
artiq: Implement Python-to-LLVM conversion of ndarray
This commit is contained in:
parent
8d9df0a615
commit
76dd5191f5
@ -1,6 +1,14 @@
|
|||||||
use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace};
|
use inkwell::{
|
||||||
|
types::{BasicType, BasicTypeEnum},
|
||||||
|
values::BasicValueEnum,
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
use nac3core::{
|
use nac3core::{
|
||||||
codegen::{CodeGenContext, CodeGenerator},
|
codegen::{
|
||||||
|
classes::{NDArrayType, ProxyType},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
helper::PrimDef,
|
helper::PrimDef,
|
||||||
@ -670,7 +678,7 @@ impl InnerResolver {
|
|||||||
}
|
}
|
||||||
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
|
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
|
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
|
||||||
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
let len: usize = obj.getattr("ndim")?.extract()?;
|
||||||
if len == 0 {
|
if len == 0 {
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
&*unifier.get_ty(ty),
|
&*unifier.get_ty(ty),
|
||||||
@ -679,10 +687,10 @@ impl InnerResolver {
|
|||||||
));
|
));
|
||||||
Ok(Ok(extracted_ty))
|
Ok(Ok(extracted_ty))
|
||||||
} else {
|
} else {
|
||||||
let actual_ty =
|
let dtype = obj.getattr("dtype")?.getattr("type")?;
|
||||||
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
|
let dtype_ty = self.get_pyty_obj_type(py, dtype, unifier, defs, primitives)?;
|
||||||
match actual_ty {
|
match dtype_ty {
|
||||||
Ok(t) => match unifier.unify(ty, t) {
|
Ok((t, _)) => match unifier.unify(ty, t) {
|
||||||
Ok(()) => {
|
Ok(()) => {
|
||||||
let ndarray_ty =
|
let ndarray_ty =
|
||||||
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims));
|
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims));
|
||||||
@ -966,7 +974,143 @@ impl InnerResolver {
|
|||||||
|
|
||||||
Ok(Some(global.as_pointer_value().into()))
|
Ok(Some(global.as_pointer_value().into()))
|
||||||
} else if ty_id == self.primitive_ids.ndarray {
|
} else if ty_id == self.primitive_ids.ndarray {
|
||||||
todo!()
|
let id_str = id.to_string();
|
||||||
|
|
||||||
|
if let Some(global) = ctx.module.get_global(&id_str) {
|
||||||
|
return Ok(Some(global.as_pointer_value().into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ndarray_ty = if matches!(&*ctx.unifier.get_ty_immutable(expected_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id())
|
||||||
|
{
|
||||||
|
expected_ty
|
||||||
|
} else {
|
||||||
|
unreachable!("must be ndarray")
|
||||||
|
};
|
||||||
|
let (ndarray_dtype, ndarray_ndims) =
|
||||||
|
unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
|
||||||
|
let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
|
||||||
|
|
||||||
|
{
|
||||||
|
if self.global_value_ids.read().contains_key(&id) {
|
||||||
|
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
||||||
|
ctx.module.add_global(
|
||||||
|
ndarray_llvm_ty.as_underlying_type(),
|
||||||
|
Some(AddressSpace::default()),
|
||||||
|
&id_str,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
return Ok(Some(global.as_pointer_value().into()));
|
||||||
|
}
|
||||||
|
self.global_value_ids.write().insert(id, obj.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndarray_ndims)
|
||||||
|
else {
|
||||||
|
unreachable!("Expected Literal for ndarray_ndims")
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndarray_ndims = if values.len() == 1 {
|
||||||
|
values[0].clone()
|
||||||
|
} else {
|
||||||
|
todo!("Unpacking literal of more than one element unimplemented")
|
||||||
|
};
|
||||||
|
let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else {
|
||||||
|
unreachable!("Expected u64 value for ndarray_ndims")
|
||||||
|
};
|
||||||
|
|
||||||
|
// Obtain the shape of the ndarray
|
||||||
|
let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?;
|
||||||
|
assert_eq!(shape_tuple.len(), ndarray_ndims as usize);
|
||||||
|
let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, elem)| {
|
||||||
|
self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err(
|
||||||
|
|e| super::CompileError::new_err(format!("Error getting element {i}: {e}")),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let shape_values = shape_values?.unwrap();
|
||||||
|
let shape_values = llvm_usize.const_array(
|
||||||
|
&shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// create a global for ndarray.shape and initialize it using the shape
|
||||||
|
let shape_global = ctx.module.add_global(
|
||||||
|
llvm_usize.array_type(ndarray_ndims as u32),
|
||||||
|
Some(AddressSpace::default()),
|
||||||
|
&(id_str.clone() + ".shape"),
|
||||||
|
);
|
||||||
|
shape_global.set_initializer(&shape_values);
|
||||||
|
|
||||||
|
// Obtain the (flattened) elements of the ndarray
|
||||||
|
let sz: usize = obj.getattr("size")?.extract()?;
|
||||||
|
let data: Result<Option<Vec<_>>, _> = (0..sz)
|
||||||
|
.map(|i| {
|
||||||
|
obj.getattr("flat")?.get_item(i).and_then(|elem| {
|
||||||
|
self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| {
|
||||||
|
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let data = data?.unwrap().into_iter();
|
||||||
|
let data = match ndarray_dtype_llvm_ty {
|
||||||
|
BasicTypeEnum::ArrayType(ty) => {
|
||||||
|
ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
BasicTypeEnum::FloatType(ty) => {
|
||||||
|
ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
BasicTypeEnum::IntType(ty) => {
|
||||||
|
ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
BasicTypeEnum::PointerType(ty) => {
|
||||||
|
ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
BasicTypeEnum::StructType(ty) => {
|
||||||
|
ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec())
|
||||||
|
}
|
||||||
|
|
||||||
|
BasicTypeEnum::VectorType(_) => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// create a global for ndarray.data and initialize it using the elements
|
||||||
|
let data_global = ctx.module.add_global(
|
||||||
|
ndarray_dtype_llvm_ty.array_type(sz as u32),
|
||||||
|
Some(AddressSpace::default()),
|
||||||
|
&(id_str.clone() + ".data"),
|
||||||
|
);
|
||||||
|
data_global.set_initializer(&data);
|
||||||
|
|
||||||
|
// create a global for the ndarray object and initialize it
|
||||||
|
let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[
|
||||||
|
llvm_usize.const_int(ndarray_ndims, false).into(),
|
||||||
|
shape_global
|
||||||
|
.as_pointer_value()
|
||||||
|
.const_cast(llvm_usize.ptr_type(AddressSpace::default()))
|
||||||
|
.into(),
|
||||||
|
data_global
|
||||||
|
.as_pointer_value()
|
||||||
|
.const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
|
||||||
|
.into(),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let ndarray = ctx.module.add_global(
|
||||||
|
ndarray_llvm_ty.as_underlying_type(),
|
||||||
|
Some(AddressSpace::default()),
|
||||||
|
&id_str,
|
||||||
|
);
|
||||||
|
ndarray.set_initializer(&value);
|
||||||
|
|
||||||
|
Ok(Some(ndarray.as_pointer_value().into()))
|
||||||
} else if ty_id == self.primitive_ids.tuple {
|
} else if ty_id == self.primitive_ids.tuple {
|
||||||
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
|
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
|
||||||
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };
|
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };
|
||||||
|
Loading…
Reference in New Issue
Block a user