From 4cffd3aa07ef3f2d0a5d207011a319481c396fab Mon Sep 17 00:00:00 2001
From: David Mak
Date: Fri, 14 Jun 2024 14:48:29 +0800
Subject: [PATCH] artiq: WIP - Implement Python-to-LLVM conversion of ndarray

---
Review notes (text between the separator above and the first file diff is
ignored by git-am, so nothing here changes what the patch applies):

What the patch does
* Adds `NumpyHelper`, three `fn(&PyAny) -> &PyAny` accessors that read the
  `shape`, `size` and `flat` attributes of a Python object, and stores it in
  `PythonHelper::np_helpers`.
* Replaces the `todo!()` in the ndarray branch of `InnerResolver` with code
  that emits three module globals per object id:
    "<id>.shape"  usize array initialised from the elements of `obj.shape`
    "<id>.data"   dtype array initialised from `obj.flat[0..obj.size]`
    "<id>"        struct { ndims, pointer to shape, pointer to data }
  An already-emitted global named "<id>" is reused instead of re-created.

Open review points
* The accessor closures in lib.rs call `getattr(..).unwrap()`, so a missing
  attribute panics instead of surfacing as a Python error.
* `shape_values?.unwrap()` and `data?.unwrap()` panic when any element
  converts to `None`.
* `ndarray_ndims as usize`, `ndarray_ndims as u32` and `sz as u32` are
  unchecked narrowing casts.
* `ndarray_flat_fn` is invoked once per element inside the `(0..sz)` loop;
  the lookup could be hoisted out of the closure.
* An ndims literal holding more than one value reaches `todo!()`.

NOTE(review): the generic arguments `Arc<InnerResolver>`,
`downcast::<PyTuple>()`, `extract::<usize>()` and `Result<Option<Vec<_>>, _>`
were restored from how the values are used - confirm against the commit.
NOTE(review): indentation of context lines follows rustfmt conventions -
confirm against the tree, or apply with `git apply --ignore-whitespace`.

 nac3artiq/src/lib.rs             |   9 +-
 nac3artiq/src/symbol_resolver.rs | 170 ++++++++++++++++++++++++++++++-
 2 files changed, 175 insertions(+), 4 deletions(-)

diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs
index 04344e2..c4777c5 100644
--- a/nac3artiq/src/lib.rs
+++ b/nac3artiq/src/lib.rs
@@ -64,7 +64,9 @@ use tempfile::{self, TempDir};
 use crate::codegen::attributes_writeback;
 use crate::{
     codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
-    symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver},
+    symbol_resolver::{
+        DeferredEvaluationStore, InnerResolver, NumpyHelper, PythonHelper, Resolver,
+    },
 };
 
 mod codegen;
@@ -329,6 +331,11 @@ impl Nac3 {
             type_fn: builtins.getattr("type").unwrap().to_object(py),
             origin_ty_fn: typings.getattr("get_origin").unwrap().to_object(py),
             args_ty_fn: typings.getattr("get_args").unwrap().to_object(py),
+            np_helpers: NumpyHelper {
+                ndarray_shape: |obj| obj.getattr("shape").unwrap(),
+                ndarray_size_fn: |obj| obj.getattr("size").unwrap(),
+                ndarray_flat_fn: |obj| obj.getattr("flat").unwrap(),
+            },
             store_obj: store_obj.clone(),
             store_str,
         };
diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs
index 0b9ede9..1cde311 100644
--- a/nac3artiq/src/symbol_resolver.rs
+++ b/nac3artiq/src/symbol_resolver.rs
@@ -1,6 +1,14 @@
-use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace};
+use inkwell::{
+    types::{BasicType, BasicTypeEnum},
+    values::BasicValueEnum,
+    AddressSpace,
+};
+use itertools::Itertools;
 use nac3core::{
-    codegen::{CodeGenContext, CodeGenerator},
+    codegen::{
+        classes::{NDArrayType, ProxyType},
+        CodeGenContext, CodeGenerator,
+    },
     symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
     toplevel::{
         helper::PrimDef,
@@ -85,6 +93,20 @@ pub struct InnerResolver {
 
 pub struct Resolver(pub Arc<InnerResolver>);
 
+/// Helpers for invoking NumPy functions.
+#[allow(clippy::struct_field_names)]
+#[derive(Clone)]
+pub struct NumpyHelper {
+    /// [`numpy.ndarray.shape`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html)
+    pub ndarray_shape: fn(&PyAny) -> &PyAny,
+
+    /// [`numpy.ndarray.size`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.size.html)
+    pub ndarray_size_fn: fn(&PyAny) -> &PyAny,
+
+    /// [`numpy.ndarray.flat`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flat.html)
+    pub ndarray_flat_fn: fn(&PyAny) -> &PyAny,
+}
+
 #[derive(Clone)]
 pub struct PythonHelper {
     pub type_fn: PyObject,
@@ -92,6 +114,10 @@ pub struct PythonHelper {
     pub id_fn: PyObject,
     pub origin_ty_fn: PyObject,
     pub args_ty_fn: PyObject,
+
+    /// See [`NumpyHelper`].
+    pub np_helpers: NumpyHelper,
+
     pub store_obj: PyObject,
     pub store_str: PyObject,
 }
@@ -958,7 +984,145 @@ impl InnerResolver {
 
             Ok(Some(global.as_pointer_value().into()))
         } else if ty_id == self.primitive_ids.ndarray {
-            todo!()
+            let id_str = id.to_string();
+
+            if let Some(global) = ctx.module.get_global(&id_str) {
+                return Ok(Some(global.as_pointer_value().into()));
+            }
+
+            let ndarray_ty = if matches!(&*ctx.unifier.get_ty_immutable(expected_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id())
+            {
+                expected_ty
+            } else {
+                unreachable!("must be ndarray")
+            };
+            let (ndarray_dtype, ndarray_ndims) =
+                unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty);
+
+            let llvm_usize = generator.get_size_type(ctx.ctx);
+            let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype);
+            let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty);
+
+            {
+                if self.global_value_ids.read().contains_key(&id) {
+                    let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
+                        ctx.module.add_global(
+                            ndarray_llvm_ty.as_underlying_type(),
+                            Some(AddressSpace::default()),
+                            &id_str,
+                        )
+                    });
+                    return Ok(Some(global.as_pointer_value().into()));
+                }
+                self.global_value_ids.write().insert(id, obj.into());
+            }
+
+            let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndarray_ndims)
+            else {
+                unreachable!("Expected Literal for ndarray_ndims")
+            };
+
+            let ndarray_ndims = if values.len() == 1 {
+                values[0].clone()
+            } else {
+                todo!("Unpacking literal of more than one element unimplemented")
+            };
+            let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else {
+                unreachable!("Expected u64 value for ndarray_ndims")
+            };
+
+            // Obtain the shape of the ndarray
+            let shape_tuple = (self.helper.np_helpers.ndarray_shape)(obj);
+            let shape_tuple = shape_tuple.downcast::<PyTuple>()?;
+            assert_eq!(shape_tuple.len(), ndarray_ndims as usize);
+            let shape_values: Result<Option<Vec<_>>, _> = shape_tuple
+                .iter()
+                .enumerate()
+                .map(|(i, elem)| {
+                    self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err(
+                        |e| super::CompileError::new_err(format!("Error getting element {i}: {e}")),
+                    )
+                })
+                .collect();
+            let shape_values = shape_values?.unwrap();
+            let shape_values = llvm_usize.const_array(
+                &shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(),
+            );
+
+            // create a global for ndarray.shape and initialize it using the shape
+            let shape_global = ctx.module.add_global(
+                llvm_usize.array_type(ndarray_ndims as u32),
+                Some(AddressSpace::default()),
+                &(id_str.clone() + ".shape"),
+            );
+            shape_global.set_initializer(&shape_values);
+
+            // Obtain the (flattened) elements of the ndarray
+            let sz = (self.helper.np_helpers.ndarray_size_fn)(obj);
+            let sz = sz.extract::<usize>()?;
+            let data: Result<Option<Vec<_>>, _> = (0..sz)
+                .map(|i| {
+                    (self.helper.np_helpers.ndarray_flat_fn)(obj).get_item(i).and_then(|elem| {
+                        self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| {
+                            super::CompileError::new_err(format!("Error getting element {i}: {e}"))
+                        })
+                    })
+                })
+                .collect();
+            let data = data?.unwrap().into_iter();
+            let data = match ndarray_dtype_llvm_ty {
+                BasicTypeEnum::ArrayType(ty) => {
+                    ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec())
+                }
+
+                BasicTypeEnum::FloatType(ty) => {
+                    ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec())
+                }
+
+                BasicTypeEnum::IntType(ty) => {
+                    ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec())
+                }
+
+                BasicTypeEnum::PointerType(ty) => {
+                    ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec())
+                }
+
+                BasicTypeEnum::StructType(ty) => {
+                    ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec())
+                }
+
+                BasicTypeEnum::VectorType(_) => unreachable!(),
+            };
+
+            // create a global for ndarray.data and initialize it using the elements
+            let data_global = ctx.module.add_global(
+                ndarray_dtype_llvm_ty.array_type(sz as u32),
+                Some(AddressSpace::default()),
+                &(id_str.clone() + ".data"),
+            );
+            data_global.set_initializer(&data);
+
+            // create a global for the ndarray object and initialize it
+            let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[
+                llvm_usize.const_int(ndarray_ndims, false).into(),
+                shape_global
+                    .as_pointer_value()
+                    .const_cast(llvm_usize.ptr_type(AddressSpace::default()))
+                    .into(),
+                data_global
+                    .as_pointer_value()
+                    .const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default()))
+                    .into(),
+            ]);
+
+            let ndarray = ctx.module.add_global(
+                ndarray_llvm_ty.as_underlying_type(),
+                Some(AddressSpace::default()),
+                &id_str,
+            );
+            ndarray.set_initializer(&value);
+
+            Ok(Some(ndarray.as_pointer_value().into()))
         } else if ty_id == self.primitive_ids.tuple {
             let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
             let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };