From 44c49dc102ef5d4e3cad1db2a49c34413fc55b33 Mon Sep 17 00:00:00 2001
From: David Mak
Date: Fri, 29 Nov 2024 16:40:40 +0800
Subject: [PATCH] [artiq] codegen: Reimplement polymorphic_print for strided ndarray

Based on 2a6ee503: artiq: reimplement polymorphic_print for ndarray
---
 nac3artiq/src/codegen.rs | 233 ++++++++++++++++++---------------------
 1 file changed, 109 insertions(+), 124 deletions(-)

diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs
index be68e104..4b5eb417 100644
--- a/nac3artiq/src/codegen.rs
+++ b/nac3artiq/src/codegen.rs
@@ -12,11 +12,12 @@ use pyo3::{
     PyObject, PyResult, Python,
 };
 
+use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
 use nac3core::{
     codegen::{
         expr::{destructure_range, gen_call},
         irrt::ndarray::call_ndarray_calc_size,
-        llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave},
+        llvm_intrinsics::{call_int_smax, call_memcpy, call_stackrestore, call_stacksave},
         stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with},
         type_aligned_alloca,
         types::ndarray::NDArrayType,
@@ -43,8 +44,6 @@ use nac3core::{
     typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap},
 };
 
-use super::{symbol_resolver::InnerResolver, timeline::TimeFns};
-
 /// The parallelism mode within a block.
 #[derive(Copy, Clone, Eq, PartialEq)]
 enum ParallelMode {
@@ -465,55 +464,47 @@ fn format_rpc_arg<'ctx>(
             let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty);
             let ndims = extract_ndims(&ctx.unifier, ndims);
 
-            let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty);
-            let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty, Some(ndims));
-            let llvm_arg = llvm_arg_ty.map_value(arg.into_pointer_value(), None);
+            let dtype = ctx.get_llvm_type(generator, elem_ty);
+            let ndarray = NDArrayType::new(generator, ctx.ctx, dtype, Some(ndims))
+                .map_value(arg.into_pointer_value(), None);
 
-            let llvm_usize_sizeof = ctx
-                .builder
-                .build_int_truncate_or_bit_cast(
-                    llvm_arg.get_type().size_type().size_of(),
-                    llvm_usize,
-                    "",
-                )
-                .unwrap();
-            let llvm_pdata_sizeof = ctx
-                .builder
-                .build_int_truncate_or_bit_cast(
-                    llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
-                    llvm_usize,
-                    "",
-                )
-                .unwrap();
+            let ndims = llvm_usize.const_int(ndims, false);
 
-            let dims_buf_sz =
-                ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
+            // `ndarray.data` is possibly not contiguous, and we need it to be contiguous for
+            // the reader.
+            // Turning it into a ContiguousNDArray to get a `data` that is contiguous.
+            let carray = ndarray.make_contiguous_ndarray(generator, ctx);
 
-            let buffer_size =
-                ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap();
+            let sizeof_usize = llvm_usize.size_of();
+            let sizeof_usize =
+                ctx.builder.build_int_z_extend_or_bit_cast(sizeof_usize, llvm_usize, "").unwrap();
 
-            let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap();
-            let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg"));
+            let sizeof_pdata = dtype.ptr_type(AddressSpace::default()).size_of();
+            let sizeof_pdata =
+                ctx.builder.build_int_z_extend_or_bit_cast(sizeof_pdata, llvm_usize, "").unwrap();
 
-            call_memcpy_generic(
-                ctx,
-                buffer.base_ptr(ctx, generator),
-                llvm_arg.ptr_to_data(ctx),
-                llvm_pdata_sizeof,
-                llvm_i1.const_zero(),
-            );
+            let sizeof_buf_shape = ctx.builder.build_int_mul(sizeof_usize, ndims, "").unwrap();
+            let sizeof_buf = ctx.builder.build_int_add(sizeof_buf_shape, sizeof_pdata, "").unwrap();
 
-            let pbuffer_dims_begin =
-                unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
-            call_memcpy_generic(
-                ctx,
-                pbuffer_dims_begin,
-                llvm_arg.shape().base_ptr(ctx, generator),
-                dims_buf_sz,
-                llvm_i1.const_zero(),
-            );
+            // buf = { data: void*, shape: [size_t; ndims]; }
+            let buf = ctx.builder.build_array_alloca(llvm_i8, sizeof_buf, "rpc.arg").unwrap();
+            let buf = ArraySliceValue::from_ptr_val(buf, sizeof_buf, Some("rpc.arg"));
+            let buf_data = buf.base_ptr(ctx, generator);
+            let buf_shape =
+                unsafe { buf.ptr_offset_unchecked(ctx, generator, &sizeof_pdata, None) };
 
-            buffer.base_ptr(ctx, generator)
+            // Write to `buf->data`
+            let carray_data = carray.load_data(ctx);
+            let carray_data = ctx.builder.build_pointer_cast(carray_data, llvm_pi8, "").unwrap();
+            call_memcpy(ctx, buf_data, carray_data, sizeof_pdata, llvm_i1.const_zero());
+
+            // Write to `buf->shape`
+            let carray_shape = ndarray.shape().base_ptr(ctx, generator);
+            let carray_shape_i8 =
+                ctx.builder.build_pointer_cast(carray_shape, llvm_pi8, "").unwrap();
+            call_memcpy(ctx, buf_shape, carray_shape_i8, sizeof_buf_shape, llvm_i1.const_zero());
+
+            buf.base_ptr(ctx, generator)
         }
 
         _ => {
@@ -554,6 +545,8 @@ fn format_rpc_ret<'ctx>(
     let llvm_i32 = ctx.ctx.i32_type();
     let llvm_i8_8 = ctx.ctx.struct_type(&[llvm_i8.array_type(8).into()], false);
     let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
+    let llvm_usize = generator.get_size_type(ctx.ctx);
+    let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
 
     let rpc_recv = ctx.module.get_function("rpc_recv").unwrap_or_else(|| {
         ctx.module.add_function("rpc_recv", llvm_i32.fn_type(&[llvm_pi8.into()], false), None)
@@ -574,8 +567,7 @@
 
     let result = match &*ctx.unifier.get_ty_immutable(ret_ty) {
         TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
-            let llvm_i1 = ctx.ctx.bool_type();
-            let llvm_usize = generator.get_size_type(ctx.ctx);
+            let num_0 = llvm_usize.const_zero();
 
             // Round `val` up to its modulo `power_of_two`
             let round_up = |ctx: &mut CodeGenContext<'ctx, '_>,
@@ -601,56 +593,49 @@
                     .unwrap()
             };
 
-            // Setup types
-            let llvm_ret_ty = NDArrayType::from_unifier_type(generator, ctx, ret_ty);
-            let llvm_elem_ty = llvm_ret_ty.element_type();
-
             // Allocate the resulting ndarray
             // A condition after format_rpc_ret ensures this will not be popped this off.
-            let ndarray = llvm_ret_ty.alloca(generator, ctx, Some("rpc.result"));
+            let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty);
+            let dtype_llvm = ctx.get_llvm_type(generator, dtype);
+            let ndims = extract_ndims(&ctx.unifier, ndims);
+            let ndarray = NDArrayType::new(generator, ctx.ctx, dtype_llvm, Some(ndims))
+                .construct_uninitialized(generator, ctx, None);
 
-            // Setup ndims
-            let ndims = llvm_ret_ty.ndims().unwrap();
-            // Set `ndarray.ndims`
-            ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false));
-            // Allocate `ndarray.shape` [size_t; ndims]
-            ndarray.create_shape(ctx, llvm_usize, ndarray.load_ndims(ctx));
+            // NOTE: Current content of `ndarray`:
+            // - * `data` - **NOT YET** allocated.
+            // - * `itemsize` - initialized to be size_of(dtype).
+            // - * `ndims` - initialized.
+            // - * `shape` - allocated; has uninitialized values.
+            // - * `strides` - allocated; has uninitialized values.
 
-            /*
-                ndarray now:
-                - .ndims: initialized
-                - .shape: allocated but uninitialized .shape
-                - .data: uninitialized
-            */
-
-            let llvm_usize_sizeof = ctx
-                .builder
-                .build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "")
-                .unwrap();
-            let llvm_pdata_sizeof = ctx
-                .builder
-                .build_int_truncate_or_bit_cast(
-                    llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(),
-                    llvm_usize,
-                    "",
-                )
-                .unwrap();
-            let llvm_elem_sizeof = ctx
-                .builder
-                .build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "")
-                .unwrap();
+            let itemsize = ndarray.load_itemsize(ctx); // Same as doing a `ctx.get_llvm_type` on `dtype` and getting its `size_of()`.
 
             // Allocates a buffer for the initial RPC'ed object, which is guaranteed to be
             // (4 + 4 * ndims) bytes with 8-byte alignment
-            let sizeof_dims =
-                ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap();
-            let buffer_size =
-                ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap();
+            let sizeof_usize = llvm_usize.size_of();
+            let sizeof_usize =
+                ctx.builder.build_int_truncate_or_bit_cast(sizeof_usize, llvm_usize, "").unwrap();
+
+            let sizeof_ptr = llvm_i8.ptr_type(AddressSpace::default()).size_of();
+            let sizeof_ptr =
+                ctx.builder.build_int_z_extend_or_bit_cast(sizeof_ptr, llvm_usize, "").unwrap();
+
+            let sizeof_shape =
+                ctx.builder.build_int_mul(ndarray.load_ndims(ctx), sizeof_usize, "").unwrap();
+
+            // Size of the buffer for the initial `rpc_recv()`.
+            let unaligned_buffer_size =
+                ctx.builder.build_int_add(sizeof_ptr, sizeof_shape, "").unwrap();
 
             let stackptr = call_stacksave(ctx, None);
-            let buffer =
-                type_aligned_alloca(generator, ctx, llvm_i8_8, buffer_size, Some("rpc.buffer"));
-            let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None);
+            let buffer = type_aligned_alloca(
+                generator,
+                ctx,
+                llvm_i8_8,
+                unaligned_buffer_size,
+                Some("rpc.buffer"),
+            );
+            let buffer = ArraySliceValue::from_ptr_val(buffer, unaligned_buffer_size, None);
 
             // The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape]
             //
@@ -658,7 +643,7 @@
             let ndarray_nbytes = ctx
                 .build_call_or_invoke(
                     rpc_recv,
-                    &[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims].
+                    &[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]
                     "rpc.size.next",
                 )
                 .map(BasicValueEnum::into_int_value)
                 .unwrap();
@@ -666,16 +651,14 @@
 
             // debug_assert(ndarray_nbytes > 0)
             if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
+                let cmp = ctx
+                    .builder
+                    .build_int_compare(IntPredicate::UGT, ndarray_nbytes, num_0, "")
+                    .unwrap();
+
                 ctx.make_assert(
                     generator,
-                    ctx.builder
-                        .build_int_compare(
-                            IntPredicate::UGT,
-                            ndarray_nbytes,
-                            ndarray_nbytes.get_type().const_zero(),
-                            "",
-                        )
-                        .unwrap(),
+                    cmp,
                     "0:AssertionError",
                     "Unexpected RPC termination for ndarray - Expected data buffer next",
                     [None, None, None],
@@ -684,49 +667,50 @@
                     ctx.current_loc,
                 );
             }
 
             // Copy shape from the buffer to `ndarray.shape`.
-            let pbuffer_dims =
-                unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) };
+            // We need to skip the first `sizeof(uint8_t*)` bytes to get past the `pdata` in `[pdata, shape]`.
+            let pbuffer_shape =
+                unsafe { buffer.ptr_offset_unchecked(ctx, generator, &sizeof_ptr, None) };
+            let pbuffer_shape =
+                ctx.builder.build_pointer_cast(pbuffer_shape, llvm_pusize, "").unwrap();
+
+            // Copy shape from buffer to `ndarray.shape`
+            ndarray.copy_shape_from_array(generator, ctx, pbuffer_shape);
 
-            call_memcpy_generic(
-                ctx,
-                ndarray.shape().base_ptr(ctx, generator),
-                pbuffer_dims,
-                sizeof_dims,
-                llvm_i1.const_zero(),
-            );
             // Restore stack from before allocation of buffer
             call_stackrestore(ctx, stackptr);
 
             // Allocate `ndarray.data`.
             // `ndarray.shape` must be initialized beforehand in this implementation
             // (for ndarray.create_data() to know how many elements to allocate)
-            let num_elements =
-                call_ndarray_calc_size(generator, ctx, &ndarray.shape(), (None, None));
+            unsafe { ndarray.create_data(generator, ctx) }; // NOTE: the strides of `ndarray` have also been set to be contiguous in `create_data`.
 
             // debug_assert(nelems * sizeof(T) >= ndarray_nbytes)
             if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
-                let sizeof_data =
-                    ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap();
+                let num_elements = ndarray.size(generator, ctx);
+
+                let expected_ndarray_nbytes =
+                    ctx.builder.build_int_mul(num_elements, itemsize, "").unwrap();
+                let cmp = ctx
+                    .builder
+                    .build_int_compare(
+                        IntPredicate::UGE,
+                        expected_ndarray_nbytes,
+                        ndarray_nbytes,
+                        "",
+                    )
+                    .unwrap();
 
                 ctx.make_assert(
                     generator,
-                    ctx.builder.build_int_compare(IntPredicate::UGE,
-                        sizeof_data,
-                        ndarray_nbytes,
-                        "",
-                    ).unwrap(),
+                    cmp,
                     "0:AssertionError",
                     "Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes",
-                    [Some(sizeof_data), Some(ndarray_nbytes), None],
+                    [Some(expected_ndarray_nbytes), Some(ndarray_nbytes), None],
                     ctx.current_loc,
                 );
             }
 
-            unsafe { ndarray.create_data(generator, ctx) };
-
             let ndarray_data = ndarray.data().base_ptr(ctx, generator);
-            let ndarray_data_i8 =
-                ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap();
 
             // NOTE: Currently on `prehead_bb`
             ctx.builder.build_unconditional_branch(head_bb).unwrap();
@@ -735,7 +719,7 @@
             ctx.builder.position_at_end(head_bb);
 
             let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap();
-            phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]);
+            phi.add_incoming(&[(&ndarray_data, prehead_bb)]);
 
             let alloc_size = ctx
                 .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
                 .map(BasicValueEnum::into_int_value)
                 .unwrap();
@@ -750,12 +734,13 @@
             ctx.builder.position_at_end(alloc_bb);
 
             // Align the allocation to sizeof(T)
-            let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof);
+            let alloc_size = round_up(ctx, alloc_size, itemsize);
+
+            // TODO(Derppening): Candidate for refactor into type_aligned_alloca
             let alloc_ptr = ctx
                 .builder
                 .build_array_alloca(
-                    llvm_elem_ty,
-                    ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(),
+                    dtype_llvm,
+                    ctx.builder.build_int_unsigned_div(alloc_size, itemsize, "").unwrap(),
                     "rpc.alloc",
                 )
                 .unwrap();