From 90134206c41183348496a3ebadc5c56c5fc85cf8 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 14 Jun 2024 14:48:29 +0800 Subject: [PATCH] artiq: Implement Python-to-LLVM conversion of ndarray --- nac3artiq/src/symbol_resolver.rs | 160 +++++++++++++++++++++++++++++-- 1 file changed, 152 insertions(+), 8 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index aa7f2a4..e06ed4d 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -1,6 +1,14 @@ -use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace}; +use inkwell::{ + types::{BasicType, BasicTypeEnum}, + values::BasicValueEnum, + AddressSpace, +}; +use itertools::Itertools; use nac3core::{ - codegen::{CodeGenContext, CodeGenerator}, + codegen::{ + classes::{NDArrayType, ProxyType}, + CodeGenContext, CodeGenerator, + }, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, toplevel::{ helper::PrimDef, @@ -670,7 +678,7 @@ impl InnerResolver { } (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => { let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty); - let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; + let len: usize = obj.getattr("ndim")?.extract()?; if len == 0 { assert!(matches!( &*unifier.get_ty(ty), @@ -679,10 +687,10 @@ impl InnerResolver { )); Ok(Ok(extracted_ty)) } else { - let actual_ty = - self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; - match actual_ty { - Ok(t) => match unifier.unify(ty, t) { + let dtype = obj.getattr("dtype")?.getattr("type")?; + let dtype_ty = self.get_pyty_obj_type(py, dtype, unifier, defs, primitives)?; + match dtype_ty { + Ok((t, _)) => match unifier.unify(ty, t) { Ok(()) => { let ndarray_ty = make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims)); @@ -966,7 +974,143 @@ impl InnerResolver { Ok(Some(global.as_pointer_value().into())) } else if ty_id == self.primitive_ids.ndarray { - todo!() + let id_str = id.to_string(); + + if let Some(global) = ctx.module.get_global(&id_str) { + return Ok(Some(global.as_pointer_value().into())); + } + + let ndarray_ty = if matches!(&*ctx.unifier.get_ty_immutable(expected_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id()) + { + expected_ty + } else { + unreachable!("must be ndarray") + }; + let (ndarray_dtype, ndarray_ndims) = + unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); + + let llvm_usize = generator.get_size_type(ctx.ctx); + let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype); + let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty); + + { + if self.global_value_ids.read().contains_key(&id) { + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global( + ndarray_llvm_ty.as_underlying_type(), + Some(AddressSpace::default()), + &id_str, + ) + }); + return Ok(Some(global.as_pointer_value().into())); + } + self.global_value_ids.write().insert(id, obj.into()); + } + + let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndarray_ndims) + else { + unreachable!("Expected Literal for ndarray_ndims") + }; + + let ndarray_ndims = if values.len() == 1 { + values[0].clone() + } else { + todo!("Unpacking literal of more than one element unimplemented") + }; + let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else { + unreachable!("Expected u64 value for ndarray_ndims") + }; + + // Obtain the shape of the ndarray + let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; + assert_eq!(shape_tuple.len(), ndarray_ndims as usize); + let shape_values: Result>, _> = shape_tuple + .iter() + .enumerate() + .map(|(i, elem)| { + self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err( + |e| super::CompileError::new_err(format!("Error getting element {i}: {e}")), + ) + }) + .collect(); + let shape_values = shape_values?.unwrap(); + let shape_values = llvm_usize.const_array( + &shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(), + ); + + // create a global for ndarray.shape and initialize it using the shape + let shape_global = ctx.module.add_global( + llvm_usize.array_type(ndarray_ndims as u32), + Some(AddressSpace::default()), + &(id_str.clone() + ".shape"), + ); + shape_global.set_initializer(&shape_values); + + // Obtain the (flattened) elements of the ndarray + let sz: usize = obj.getattr("size")?.extract()?; + let data: Result>, _> = (0..sz) + .map(|i| { + obj.getattr("flat")?.get_item(i).and_then(|elem| { + self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| { + super::CompileError::new_err(format!("Error getting element {i}: {e}")) + }) + }) + }) + .collect(); + let data = data?.unwrap().into_iter(); + let data = match ndarray_dtype_llvm_ty { + BasicTypeEnum::ArrayType(ty) => { + ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec()) + } + + BasicTypeEnum::FloatType(ty) => { + ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec()) + } + + BasicTypeEnum::IntType(ty) => { + ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec()) + } + + BasicTypeEnum::PointerType(ty) => { + ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec()) + } + + BasicTypeEnum::StructType(ty) => { + ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec()) + } + + BasicTypeEnum::VectorType(_) => unreachable!(), + }; + + // create a global for ndarray.data and initialize it using the elements + let data_global = ctx.module.add_global( + ndarray_dtype_llvm_ty.array_type(sz as u32), + Some(AddressSpace::default()), + &(id_str.clone() + ".data"), + ); + data_global.set_initializer(&data); + + // create a global for the ndarray object and initialize it + let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[ + llvm_usize.const_int(ndarray_ndims, false).into(), + shape_global + .as_pointer_value() + .const_cast(llvm_usize.ptr_type(AddressSpace::default())) + .into(), + data_global + .as_pointer_value() + .const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default())) + .into(), + ]); + + let ndarray = ctx.module.add_global( + ndarray_llvm_ty.as_underlying_type(), + Some(AddressSpace::default()), + &id_str, + ); + ndarray.set_initializer(&value); + + Ok(Some(ndarray.as_pointer_value().into())) } else if ty_id == self.primitive_ids.tuple { let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };