From 2e75d5a730d95312afd36b9dad7238ab93f32e5a Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 22 Aug 2024 13:05:03 +0800 Subject: [PATCH 1/4] artiq: reimplement reformat_rpc_arg for ndarray --- nac3artiq/src/codegen.rs | 232 ++++++++++++++++----------------------- 1 file changed, 92 insertions(+), 140 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 7c43d885..41bde3b5 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -14,26 +14,29 @@ use pyo3::{ use nac3core::{ codegen::{ - classes::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayType, - NDArrayValue, ProxyType, ProxyValue, RangeValue, UntypedArrayLikeAccessor, - }, + classes::{ListValue, NDArrayValue, RangeValue, UntypedArrayLikeAccessor}, expr::{destructure_range, gen_call}, irrt::call_ndarray_calc_size, - llvm_intrinsics::{call_int_smax, call_memcpy_generic, call_stackrestore, call_stacksave}, + llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave}, + model::*, + object::{any::AnyObject, ndarray::NDArrayObject}, stmt::{gen_block, gen_for_callback_incrementing, gen_if_callback, gen_with}, CodeGenContext, CodeGenerator, }, inkwell::{ context::Context, module::Linkage, - types::{BasicType, IntType}, - values::{BasicValueEnum, IntValue, PointerValue, StructValue}, + types::IntType, + values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue}, AddressSpace, IntPredicate, OptimizationLevel, }, nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}, symbol_resolver::ValueEnum, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, GenCall}, + toplevel::{ + helper::{extract_ndims, PrimDef}, + numpy::unpack_ndarray_var_tys, + DefinitionId, GenCall, + }, typecheck::typedef::{iter_type_vars, FunSignature, FuncArg, Type, TypeEnum, VarMap}, }; @@ -454,55 +457,42 @@ fn format_rpc_arg<'ctx>( // NAC3: NDArray = { usize, usize*, T* } // libproto_artiq: NDArray = [data[..], dim_sz[..]] - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + let ndarray = AnyObject { ty: arg_ty, value: arg }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, arg_ty); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let llvm_arg_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty); - let llvm_arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); + let dtype = ctx.get_llvm_type(generator, ndarray.dtype); + let ndims = ndarray.ndims_llvm(generator, ctx.ctx); - let llvm_usize_sizeof = ctx - .builder - .build_int_truncate_or_bit_cast(llvm_arg_ty.size_type().size_of(), llvm_usize, "") - .unwrap(); - let llvm_pdata_sizeof = ctx - .builder - .build_int_truncate_or_bit_cast( - llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(), - llvm_usize, - "", - ) - .unwrap(); + // `ndarray.data` is possibly not contiguous, and we need it to be contiguous for + // the reader. + // Turning it into a ContiguousNDArray to get a `data` that is contiguous. + let carray = ndarray.make_contiguous_ndarray(generator, ctx, Any(dtype)); - let dims_buf_sz = - ctx.builder.build_int_mul(llvm_arg.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); + let sizeof_sizet = Int(SizeT).size_of(generator, ctx.ctx); + let sizeof_sizet = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_sizet); - let buffer_size = - ctx.builder.build_int_add(dims_buf_sz, llvm_pdata_sizeof, "").unwrap(); + let sizeof_pdata = Ptr(Any(dtype)).size_of(generator, ctx.ctx); + let sizeof_pdata = Int(SizeT).truncate_or_bit_cast(generator, ctx, sizeof_pdata); - let buffer = ctx.builder.build_array_alloca(llvm_i8, buffer_size, "rpc.arg").unwrap(); - let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, Some("rpc.arg")); + let sizeof_buf_shape = sizeof_sizet.mul(ctx, ndims); + let sizeof_buf = sizeof_buf_shape.add(ctx, sizeof_pdata); - call_memcpy_generic( - ctx, - buffer.base_ptr(ctx, generator), - llvm_arg.ptr_to_data(ctx), - llvm_pdata_sizeof, - llvm_i1.const_zero(), - ); + // buf = { data: void*, shape: [size_t; ndims]; } + let buf = Int(Byte).array_alloca(generator, ctx, sizeof_buf.value); + let buf_data = buf; + let buf_shape = buf_data.offset(ctx, sizeof_pdata.value); - let pbuffer_dims_begin = - unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; - call_memcpy_generic( - ctx, - pbuffer_dims_begin, - llvm_arg.dim_sizes().base_ptr(ctx, generator), - dims_buf_sz, - llvm_i1.const_zero(), - ); + // Write to `buf->data` + let carray_data = carray.get(generator, ctx, |f| f.data); // has type Ptr + let carray_data = carray_data.pointer_cast(generator, ctx, Int(Byte)); + buf_data.copy_from(generator, ctx, carray_data, sizeof_pdata.value); - buffer.base_ptr(ctx, generator) + // Write to `buf->shape` + let carray_shape = ndarray.instance.get(generator, ctx, |f| f.shape); + let carray_shape_i8 = carray_shape.pointer_cast(generator, ctx, Int(Byte)); + buf_shape.copy_from(generator, ctx, carray_shape_i8, sizeof_buf_shape.value); + + buf.value } _ => { @@ -563,8 +553,10 @@ fn format_rpc_ret<'ctx>( let result = match &*ctx.unifier.get_ty_immutable(ret_ty) { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + // FIXME: It is possible to rewrite everything more neatly with `Model<'ctx>`, but this is not too important. + + let num_0 = Int(SizeT).const_0(generator, ctx.ctx); + let num_8 = Int(SizeT).const_int(generator, ctx.ctx, 8, false); // Round `val` up to its modulo `power_of_two` let round_up = |ctx: &mut CodeGenContext<'ctx, '_>, @@ -590,60 +582,36 @@ fn format_rpc_ret<'ctx>( .unwrap() }; - // Setup types - let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); - let llvm_elem_ty = ctx.get_llvm_type(generator, elem_ty); - let llvm_ret_ty = NDArrayType::new(generator, ctx.ctx, llvm_elem_ty); - // Allocate the resulting ndarray // A condition after format_rpc_ret ensures this will not be popped this off. - let ndarray = llvm_ret_ty.new_value(generator, ctx, Some("rpc.result")); + let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ret_ty); + let ndims = extract_ndims(&ctx.unifier, ndims); + let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims); - // Setup ndims - let ndims = - if let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) { - assert_eq!(values.len(), 1); + // NOTE: Current content of `ndarray`: + // - * `data` - **NOT YET** allocated. + // - * `itemsize` - initialized to be size_of(dtype). + // - * `ndims` - initialized. + // - * `shape` - allocated; has uninitialized values. + // - * `strides` - allocated; has uninitialized values. - u64::try_from(values[0].clone()).unwrap() - } else { - unreachable!(); - }; - // Set `ndarray.ndims` - ndarray.store_ndims(ctx, generator, llvm_usize.const_int(ndims, false)); - // Allocate `ndarray.shape` [size_t; ndims] - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray.load_ndims(ctx)); - - /* - ndarray now: - - .ndims: initialized - - .shape: allocated but uninitialized .shape - - .data: uninitialized - */ - - let llvm_usize_sizeof = ctx - .builder - .build_int_truncate_or_bit_cast(llvm_usize.size_of(), llvm_usize, "") - .unwrap(); - let llvm_pdata_sizeof = ctx - .builder - .build_int_truncate_or_bit_cast( - llvm_elem_ty.ptr_type(AddressSpace::default()).size_of(), - llvm_usize, - "", - ) - .unwrap(); - let llvm_elem_sizeof = ctx - .builder - .build_int_truncate_or_bit_cast(llvm_elem_ty.size_of().unwrap(), llvm_usize, "") - .unwrap(); + let itemsize = ndarray.instance.get(generator, ctx, |f| f.itemsize); // Same as doing a `ctx.get_llvm_type` on `dtype` and get its `size_of()`. + let dtype_llvm = ctx.get_llvm_type(generator, dtype); // Allocates a buffer for the initial RPC'ed object, which is guaranteed to be // (4 + 4 * ndims) bytes with 8-byte alignment - let sizeof_dims = - ctx.builder.build_int_mul(ndarray.load_ndims(ctx), llvm_usize_sizeof, "").unwrap(); - let unaligned_buffer_size = - ctx.builder.build_int_add(sizeof_dims, llvm_pdata_sizeof, "").unwrap(); - let buffer_size = round_up(ctx, unaligned_buffer_size, llvm_usize.const_int(8, false)); + let sizeof_size_t = Int(SizeT).size_of(generator, ctx.ctx); + let sizeof_size_t = Int(SizeT).z_extend_or_truncate(generator, ctx, sizeof_size_t); // sizeof(size_t) + + let sizeof_ptr = Ptr(Int(Byte)).size_of(generator, ctx.ctx); + let sizeof_ptr = Int(SizeT).z_extend_or_truncate(generator, ctx, sizeof_ptr); // sizeof(uint8_t*) + + let sizeof_shape = ndarray.ndims_llvm(generator, ctx.ctx).mul(ctx, sizeof_size_t); // sizeof([size_t; ndims]); same as the # of bytes of `ndarray.shape`. + + // Size of the buffer for the initial `rpc_recv()`. + let unaligned_buffer_size = sizeof_ptr.add(ctx, sizeof_shape); // sizeof(uint8_t*) + sizeof([size_t; ndims]). + let buffer_size = round_up(ctx, unaligned_buffer_size.value, num_8.value); + let buffer_size = unsafe { Int(SizeT).believe_value(buffer_size) }; let stackptr = call_stacksave(ctx, None); // Just to be absolutely sure, alloca in [i8 x 8] slices to force 8-byte alignment @@ -651,9 +619,7 @@ fn format_rpc_ret<'ctx>( .builder .build_array_alloca( llvm_i8_8, - ctx.builder - .build_int_unsigned_div(buffer_size, llvm_usize.const_int(8, false), "") - .unwrap(), + ctx.builder.build_int_unsigned_div(buffer_size.value, num_8.value, "").unwrap(), "rpc.buffer", ) .unwrap(); @@ -662,7 +628,7 @@ fn format_rpc_ret<'ctx>( .build_bit_cast(buffer, llvm_pi8, "") .map(BasicValueEnum::into_pointer_value) .unwrap(); - let buffer = ArraySliceValue::from_ptr_val(buffer, buffer_size, None); + let buffer = unsafe { Ptr(Int(Byte)).believe_value(buffer) }; // The first call to `rpc_recv` reads the top-level ndarray object: [pdata, shape] // @@ -670,24 +636,20 @@ fn format_rpc_ret<'ctx>( let ndarray_nbytes = ctx .build_call_or_invoke( rpc_recv, - &[buffer.base_ptr(ctx, generator).into()], // Reads [usize; ndims]. NOTE: We are allocated [size_t; ndims]. + &[buffer.value.into()], // Reads [usize; ndims] "rpc.size.next", ) .map(BasicValueEnum::into_int_value) .unwrap(); + let ndarray_nbytes = unsafe { Int(SizeT).believe_value(ndarray_nbytes) }; // debug_assert(ndarray_nbytes > 0) if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { + let cmp = ndarray_nbytes.compare(ctx, IntPredicate::UGT, num_0); + ctx.make_assert( generator, - ctx.builder - .build_int_compare( - IntPredicate::UGT, - ndarray_nbytes, - ndarray_nbytes.get_type().const_zero(), - "", - ) - .unwrap(), + cmp.value, "0:AssertionError", "Unexpected RPC termination for ndarray - Expected data buffer next", [None, None, None], @@ -696,49 +658,39 @@ fn format_rpc_ret<'ctx>( } // Copy shape from the buffer to `ndarray.shape`. - let pbuffer_dims = - unsafe { buffer.ptr_offset_unchecked(ctx, generator, &llvm_pdata_sizeof, None) }; + // We need to skip the first `sizeof(uint8_t*)` bytes to skip the `pdata` in `[pdata, shape]`. + let pbuffer_shape = buffer.offset(ctx, sizeof_ptr.value); + let pbuffer_shape = pbuffer_shape.pointer_cast(generator, ctx, Int(SizeT)); + + // Copy shape from buffer to `ndarray.shape` + ndarray.copy_shape_from_array(generator, ctx, pbuffer_shape); - call_memcpy_generic( - ctx, - ndarray.dim_sizes().base_ptr(ctx, generator), - pbuffer_dims, - sizeof_dims, - llvm_i1.const_zero(), - ); // Restore stack from before allocation of buffer call_stackrestore(ctx, stackptr); // Allocate `ndarray.data`. // `ndarray.shape` must be initialized beforehand in this implementation // (for ndarray.create_data() to know how many elements to allocate) - let num_elements = - call_ndarray_calc_size(generator, ctx, &ndarray.dim_sizes(), (None, None)); + ndarray.create_data(generator, ctx); // NOTE: the strides of `ndarray` has also been set to contiguous in `::create_data()`. // debug_assert(nelems * sizeof(T) >= ndarray_nbytes) if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let sizeof_data = - ctx.builder.build_int_mul(num_elements, llvm_elem_sizeof, "").unwrap(); + let num_elements = ndarray.size(generator, ctx); + + let expected_ndarray_nbytes = num_elements.mul(ctx, itemsize); + let cmp = expected_ndarray_nbytes.compare(ctx, IntPredicate::UGE, ndarray_nbytes); ctx.make_assert( generator, - ctx.builder.build_int_compare(IntPredicate::UGE, - sizeof_data, - ndarray_nbytes, - "", - ).unwrap(), + cmp.value, "0:AssertionError", "Unexpected allocation size request for ndarray data - Expected up to {0} bytes, got {1} bytes", - [Some(sizeof_data), Some(ndarray_nbytes), None], + [Some(expected_ndarray_nbytes.value), Some(ndarray_nbytes.value), None], ctx.current_loc, ); } - ndarray.create_data(ctx, llvm_elem_ty, num_elements); - - let ndarray_data = ndarray.data().base_ptr(ctx, generator); - let ndarray_data_i8 = - ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap(); + let ndarray_data = ndarray.instance.get(generator, ctx, |f| f.data); // NOTE: Currently on `prehead_bb` ctx.builder.build_unconditional_branch(head_bb).unwrap(); @@ -747,7 +699,7 @@ fn format_rpc_ret<'ctx>( ctx.builder.position_at_end(head_bb); let phi = ctx.builder.build_phi(llvm_pi8, "rpc.ptr").unwrap(); - phi.add_incoming(&[(&ndarray_data_i8, prehead_bb)]); + phi.add_incoming(&[(&ndarray_data.value, prehead_bb)]); let alloc_size = ctx .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") @@ -762,12 +714,12 @@ fn format_rpc_ret<'ctx>( ctx.builder.position_at_end(alloc_bb); // Align the allocation to sizeof(T) - let alloc_size = round_up(ctx, alloc_size, llvm_elem_sizeof); + let alloc_size = round_up(ctx, alloc_size, itemsize.value); let alloc_ptr = ctx .builder .build_array_alloca( - llvm_elem_ty, - ctx.builder.build_int_unsigned_div(alloc_size, llvm_elem_sizeof, "").unwrap(), + dtype_llvm, + ctx.builder.build_int_unsigned_div(alloc_size, itemsize.value, "").unwrap(), "rpc.alloc", ) .unwrap(); @@ -777,7 +729,7 @@ fn format_rpc_ret<'ctx>( ctx.builder.build_unconditional_branch(head_bb).unwrap(); ctx.builder.position_at_end(tail_bb); - ndarray.as_base_value().into() + ndarray.instance.value.as_basic_value_enum() } _ => { -- 2.44.2 From 2a6ee503bac63af1d574adb653d6298f1a30eb37 Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 22 Aug 2024 13:19:39 +0800 Subject: [PATCH 2/4] artiq: reimplement polymorphic_print for ndarray --- nac3artiq/src/codegen.rs | 79 +++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 45 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 41bde3b5..46707f4d 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -14,9 +14,8 @@ use pyo3::{ use nac3core::{ codegen::{ - classes::{ListValue, NDArrayValue, RangeValue, UntypedArrayLikeAccessor}, + classes::{ListValue, RangeValue, UntypedArrayLikeAccessor}, expr::{destructure_range, gen_call}, - irrt::call_ndarray_calc_size, llvm_intrinsics::{call_int_smax, call_stackrestore, call_stacksave}, model::*, object::{any::AnyObject, ndarray::NDArrayObject}, @@ -1311,56 +1310,46 @@ fn polymorphic_print<'ctx>( } TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); - fmt.push_str("array(["); flush(ctx, generator, &mut fmt, &mut args); - let val = NDArrayValue::from_ptr_val(value.into_pointer_value(), llvm_usize, None); - let len = call_ndarray_calc_size(generator, ctx, &val.dim_sizes(), (None, None)); - let last = - ctx.builder.build_int_sub(len, llvm_usize.const_int(1, false), "").unwrap(); + let ndarray = AnyObject { ty, value }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (len, false), - |generator, ctx, _, i| { - let elem = unsafe { val.data().get_unchecked(ctx, generator, &i, None) }; + let num_0 = Int(SizeT).const_0(generator, ctx.ctx); - polymorphic_print( - ctx, - generator, - &[(elem_ty, elem.into())], - "", - None, - true, - as_rtio, - )?; + // Print `ndarray` as a flat list delimited by interspersed with ", \0" + ndarray.foreach(generator, ctx, |generator, ctx, _, hdl| { + let i = hdl.get_index(generator, ctx); + let scalar = hdl.get_scalar(generator, ctx); - gen_if_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::ULT, i, last, "") - .unwrap()) - }, - |generator, ctx| { - printf(ctx, generator, ", \0".into(), Vec::default()); + // if (i != 0) { puts(", "); } + gen_if_callback( + generator, + ctx, + |_, ctx| { + let not_first = i.compare(ctx, IntPredicate::NE, num_0); + Ok(not_first.value) + }, + |generator, ctx| { + printf(ctx, generator, ", \0".into(), Vec::default()); + Ok(()) + }, + |_, _| Ok(()), + )?; - Ok(()) - }, - |_, _| Ok(()), - )?; - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; + // Print element + polymorphic_print( + ctx, + generator, + &[(scalar.ty, scalar.value.into())], + "", + None, + true, + as_rtio, + )?; + Ok(()) + })?; fmt.push_str(")]"); flush(ctx, generator, &mut fmt, &mut args); -- 2.44.2 From 7ef934722ec0608e9ccdf623a08093f2032314d2 Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 22 Aug 2024 16:19:09 +0800 Subject: [PATCH 3/4] artiq: reimplement get_obj_value to use ndarray with strides --- nac3artiq/src/symbol_resolver.rs | 170 +++++++++++++-------- nac3core/src/codegen/object/ndarray/mod.rs | 15 ++ 2 files changed, 118 insertions(+), 67 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index fd8ed0db..3f305bd1 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -10,18 +10,19 @@ use itertools::Itertools; use parking_lot::RwLock; use pyo3::{ types::{PyDict, PyTuple}, - PyAny, PyObject, PyResult, Python, + PyAny, PyErr, PyObject, PyResult, Python, }; use nac3core::{ codegen::{ - classes::{NDArrayType, ProxyType}, + model::*, + object::ndarray::{make_contiguous_strides, NDArray}, CodeGenContext, CodeGenerator, }, inkwell::{ module::Linkage, - types::{BasicType, BasicTypeEnum}, - values::BasicValueEnum, + types::BasicType, + values::{BasicValue, BasicValueEnum}, AddressSpace, }, nac3parser::ast::{self, StrRef}, @@ -1088,15 +1089,12 @@ impl InnerResolver { let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); - let llvm_usize = generator.get_size_type(ctx.ctx); - let ndarray_dtype_llvm_ty = ctx.get_llvm_type(generator, ndarray_dtype); - let ndarray_llvm_ty = NDArrayType::new(generator, ctx.ctx, ndarray_dtype_llvm_ty); - + let dtype = Any(ctx.get_llvm_type(generator, ndarray_dtype)); { if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global( - ndarray_llvm_ty.as_underlying_type(), + Struct(NDArray).llvm_type(generator, ctx.ctx), Some(AddressSpace::default()), &id_str, ) @@ -1116,100 +1114,138 @@ impl InnerResolver { } else { todo!("Unpacking literal of more than one element unimplemented") }; - let Ok(ndarray_ndims) = u64::try_from(ndarray_ndims) else { + let Ok(ndims) = u64::try_from(ndarray_ndims) else { unreachable!("Expected u64 value for ndarray_ndims") }; // Obtain the shape of the ndarray let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; - assert_eq!(shape_tuple.len(), ndarray_ndims as usize); - let shape_values: Result>, _> = shape_tuple + assert_eq!(shape_tuple.len(), ndims as usize); + + // The Rust type inferencer cannot figure this out + let shape_values: Result>>, PyErr> = shape_tuple .iter() .enumerate() .map(|(i, elem)| { - self.get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()).map_err( - |e| super::CompileError::new_err(format!("Error getting element {i}: {e}")), - ) + let value = self + .get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()) + .map_err(|e| { + super::CompileError::new_err(format!("Error getting element {i}: {e}")) + })? + .unwrap(); + let value = Int(SizeT).check_value(generator, ctx.ctx, value).unwrap(); + Ok(value) }) .collect(); - let shape_values = shape_values?.unwrap(); - let shape_values = llvm_usize.const_array( - &shape_values.into_iter().map(BasicValueEnum::into_int_value).collect_vec(), - ); + let shape_values = shape_values?; + + // Also use this opportunity to get the constant values of `shape_values` for calculating strides. + let shape_u64s = shape_values + .iter() + .map(|dim| { + assert!(dim.value.is_const()); + dim.value.get_zero_extended_constant().unwrap() + }) + .collect_vec(); + let shape_values = Int(SizeT).const_array(generator, ctx.ctx, &shape_values); // create a global for ndarray.shape and initialize it using the shape let shape_global = ctx.module.add_global( - llvm_usize.array_type(ndarray_ndims as u32), + Array { len: AnyLen(ndims as u32), item: Int(SizeT) }.llvm_type(generator, ctx.ctx), Some(AddressSpace::default()), &(id_str.clone() + ".shape"), ); - shape_global.set_initializer(&shape_values); + shape_global.set_initializer(&shape_values.value); // Obtain the (flattened) elements of the ndarray let sz: usize = obj.getattr("size")?.extract()?; - let data: Result>, _> = (0..sz) + let data_values: Vec> = (0..sz) .map(|i| { obj.getattr("flat")?.get_item(i).and_then(|elem| { - self.get_obj_value(py, elem, ctx, generator, ndarray_dtype).map_err(|e| { - super::CompileError::new_err(format!("Error getting element {i}: {e}")) - }) + let value = self + .get_obj_value(py, elem, ctx, generator, ndarray_dtype) + .map_err(|e| { + super::CompileError::new_err(format!( + "Error getting element {i}: {e}" + )) + })? + .unwrap(); + + let value = dtype.check_value(generator, ctx.ctx, value).unwrap(); + Ok(value) }) }) - .collect(); - let data = data?.unwrap().into_iter(); - let data = match ndarray_dtype_llvm_ty { - BasicTypeEnum::ArrayType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec()) - } - - BasicTypeEnum::FloatType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec()) - } - - BasicTypeEnum::IntType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec()) - } - - BasicTypeEnum::PointerType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec()) - } - - BasicTypeEnum::StructType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec()) - } - - BasicTypeEnum::VectorType(_) => unreachable!(), - }; + .try_collect()?; + let data = dtype.const_array(generator, ctx.ctx, &data_values); // create a global for ndarray.data and initialize it using the elements + // + // NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`. + // We will have to cast it to an `u8*` later. let data_global = ctx.module.add_global( - ndarray_dtype_llvm_ty.array_type(sz as u32), + Array { len: AnyLen(sz as u32), item: dtype }.llvm_type(generator, ctx.ctx), Some(AddressSpace::default()), &(id_str.clone() + ".data"), ); - data_global.set_initializer(&data); + data_global.set_initializer(&data.value); + + // Get the constant itemsize. + let itemsize = dtype.llvm_type(generator, ctx.ctx).size_of().unwrap(); + let itemsize = itemsize.get_zero_extended_constant().unwrap(); + + // Create the strides needed for ndarray.strides + let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s); + let strides = strides + .into_iter() + .map(|stride| Int(SizeT).const_int(generator, ctx.ctx, stride, false)) + .collect_vec(); + let strides = Int(SizeT).const_array(generator, ctx.ctx, &strides); + + // create a global for ndarray.strides and initialize it + let strides_global = ctx.module.add_global( + Array { len: AnyLen(ndims as u32), item: Int(Byte) }.llvm_type(generator, ctx.ctx), + Some(AddressSpace::default()), + &(id_str.clone() + ".strides"), + ); + strides_global.set_initializer(&strides.value); // create a global for the ndarray object and initialize it - let value = ndarray_llvm_ty.as_underlying_type().const_named_struct(&[ - llvm_usize.const_int(ndarray_ndims, false).into(), - shape_global - .as_pointer_value() - .const_cast(llvm_usize.ptr_type(AddressSpace::default())) - .into(), - data_global - .as_pointer_value() - .const_cast(ndarray_dtype_llvm_ty.ptr_type(AddressSpace::default())) - .into(), - ]); + // We are also doing [`Model::check_value`] instead of [`Model::believe_value`] to catch bugs. - let ndarray = ctx.module.add_global( - ndarray_llvm_ty.as_underlying_type(), + // NOTE: data_global is an array of dtype, we want a `u8*`. + let ndarray_data = Ptr(dtype).check_value(generator, ctx.ctx, data_global).unwrap(); + let ndarray_data = Ptr(Int(Byte)).pointer_cast(generator, ctx, ndarray_data.value); + + let ndarray_itemsize = Int(SizeT).const_int(generator, ctx.ctx, itemsize, false); + + let ndarray_ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims, false); + + let ndarray_shape = + Ptr(Int(SizeT)).check_value(generator, ctx.ctx, shape_global).unwrap(); + + let ndarray_strides = + Ptr(Int(SizeT)).check_value(generator, ctx.ctx, strides_global).unwrap(); + + let ndarray = Struct(NDArray).const_struct( + generator, + ctx.ctx, + &[ + ndarray_data.value.as_basic_value_enum(), + ndarray_itemsize.value.as_basic_value_enum(), + ndarray_ndims.value.as_basic_value_enum(), + ndarray_shape.value.as_basic_value_enum(), + ndarray_strides.value.as_basic_value_enum(), + ], + ); + + let ndarray_global = ctx.module.add_global( + Struct(NDArray).llvm_type(generator, ctx.ctx), Some(AddressSpace::default()), &id_str, ); - ndarray.set_initializer(&value); + ndarray_global.set_initializer(&ndarray.value); - Ok(Some(ndarray.as_pointer_value().into())) + Ok(Some(ndarray_global.as_pointer_value().into())) } else if ty_id == self.primitive_ids.tuple { let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else { diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 5d3684a3..25fbd4e5 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -652,3 +652,18 @@ impl<'ctx> NDArrayOut<'ctx> { } } } + +/// A version of [`call_nac3_ndarray_set_strides_by_shape`] in Rust. +/// +/// This function is used generating strides for globally defined contiguous ndarrays. +#[must_use] +pub fn make_contiguous_strides(itemsize: u64, ndims: u64, shape: &[u64]) -> Vec { + let mut strides = Vec::with_capacity(ndims as usize); + let mut stride_product = 1u64; + for i in 0..ndims { + let axis = ndims - i - 1; + strides[axis as usize] = stride_product * itemsize; + stride_product *= shape[axis as usize]; + } + strides +} -- 2.44.2 From 33555be7e02e9ed53a310d10644241977a4d75df Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 22 Aug 2024 16:21:01 +0800 Subject: [PATCH 4/4] core: remove old ndarray code and NDArray proxy Nothing depends on the old ndarray implementation now. --- nac3core/irrt/irrt.cpp | 1 - nac3core/irrt/irrt/ndarray.hpp | 151 ---- nac3core/src/codegen/classes.rs | 630 +------------ nac3core/src/codegen/irrt/mod.rs | 376 +------- nac3core/src/codegen/numpy.rs | 1423 +----------------------------- nac3core/src/codegen/test.rs | 14 +- 6 files changed, 8 insertions(+), 2587 deletions(-) delete mode 100644 nac3core/irrt/irrt/ndarray.hpp diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 70ef2392..e652dbaf 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -2,7 +2,6 @@ #include "irrt/int_types.hpp" #include "irrt/list.hpp" #include "irrt/math.hpp" -#include "irrt/ndarray.hpp" #include "irrt/range.hpp" #include "irrt/slice.hpp" #include "irrt/ndarray/basic.hpp" diff --git a/nac3core/irrt/irrt/ndarray.hpp b/nac3core/irrt/irrt/ndarray.hpp deleted file mode 100644 index 72ca0b9e..00000000 --- a/nac3core/irrt/irrt/ndarray.hpp +++ /dev/null @@ -1,151 +0,0 @@ -#pragma once - -#include "irrt/int_types.hpp" - -// TODO: To be deleted since NDArray with strides is done. - -namespace { -template -SizeT __nac3_ndarray_calc_size_impl(const SizeT* list_data, SizeT list_len, SizeT begin_idx, SizeT end_idx) { - __builtin_assume(end_idx <= list_len); - - SizeT num_elems = 1; - for (SizeT i = begin_idx; i < end_idx; ++i) { - SizeT val = list_data[i]; - __builtin_assume(val > 0); - num_elems *= val; - } - return num_elems; -} - -template -void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT num_dims, NDIndexInt* idxs) { - SizeT stride = 1; - for (SizeT dim = 0; dim < num_dims; dim++) { - SizeT i = num_dims - dim - 1; - __builtin_assume(dims[i] > 0); - idxs[i] = (index / stride) % dims[i]; - stride *= dims[i]; - } -} - -template -SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims, - SizeT num_dims, - const NDIndexInt* indices, - SizeT num_indices) { - SizeT idx = 0; - SizeT stride = 1; - for (SizeT i = 0; i < num_dims; ++i) { - SizeT ri = num_dims - i - 1; - if (ri < num_indices) { - idx += stride * indices[ri]; - } - - __builtin_assume(dims[i] > 0); - stride *= dims[ri]; - } - return idx; -} - -template -void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims, - SizeT lhs_ndims, - const SizeT* rhs_dims, - SizeT rhs_ndims, - SizeT* out_dims) { - SizeT max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims; - - for (SizeT i = 0; i < max_ndims; ++i) { - const SizeT* lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : nullptr; - const SizeT* rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : nullptr; - SizeT* out_dim = &out_dims[max_ndims - i - 1]; - - if (lhs_dim_sz == nullptr) { - *out_dim = *rhs_dim_sz; - } else if (rhs_dim_sz == nullptr) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == 1) { - *out_dim = *rhs_dim_sz; - } else if (*rhs_dim_sz == 1) { - *out_dim = *lhs_dim_sz; - } else if (*lhs_dim_sz == *rhs_dim_sz) { - *out_dim = *lhs_dim_sz; - } else { - __builtin_unreachable(); - } - } -} - -template -void __nac3_ndarray_calc_broadcast_idx_impl(const SizeT* src_dims, - SizeT src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - for (SizeT i = 0; i < src_ndims; ++i) { - SizeT src_i = src_ndims - i - 1; - out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i]; - } -} -} // namespace - -extern "C" { -uint32_t __nac3_ndarray_calc_size(const uint32_t* list_data, uint32_t list_len, uint32_t begin_idx, uint32_t end_idx) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); -} - -uint64_t -__nac3_ndarray_calc_size64(const uint64_t* list_data, uint64_t list_len, uint64_t begin_idx, uint64_t end_idx) { - return __nac3_ndarray_calc_size_impl(list_data, list_len, begin_idx, end_idx); -} - -void __nac3_ndarray_calc_nd_indices(uint32_t index, const uint32_t* dims, uint32_t num_dims, NDIndexInt* idxs) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); -} - -void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint64_t num_dims, NDIndexInt* idxs) { - __nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs); -} - -uint32_t -__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); -} - -uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims, - uint64_t num_dims, - const NDIndexInt* indices, - uint64_t num_indices) { - return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices); -} - -void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims, - uint32_t lhs_ndims, - const uint32_t* rhs_dims, - uint32_t rhs_ndims, - uint32_t* out_dims) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); -} - -void __nac3_ndarray_calc_broadcast64(const uint64_t* lhs_dims, - uint64_t lhs_ndims, - const uint64_t* rhs_dims, - uint64_t rhs_ndims, - uint64_t* out_dims) { - return __nac3_ndarray_calc_broadcast_impl(lhs_dims, lhs_ndims, rhs_dims, rhs_ndims, out_dims); -} - -void __nac3_ndarray_calc_broadcast_idx(const uint32_t* src_dims, - uint32_t src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); -} - -void __nac3_ndarray_calc_broadcast_idx64(const uint64_t* src_dims, - uint64_t src_ndims, - const NDIndexInt* in_idx, - NDIndexInt* out_idx) { - __nac3_ndarray_calc_broadcast_idx_impl(src_dims, src_ndims, in_idx, out_idx); -} -} \ No newline at end of file diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 8628aaa7..8ad48a2e 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -5,12 +5,7 @@ use inkwell::{ AddressSpace, IntPredicate, }; -use super::{ - irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, - llvm_intrinsics::call_int_umin, - stmt::gen_for_callback_incrementing, - CodeGenContext, CodeGenerator, -}; +use super::{CodeGenContext, CodeGenerator}; /// A LLVM type that is used to represent a non-primitive type in NAC3. pub trait ProxyType<'ctx>: Into { @@ -1140,626 +1135,3 @@ impl<'ctx> From> for PointerValue<'ctx> { value.as_base_value() } } - -/// Proxy type for a `ndarray` type in LLVM. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct NDArrayType<'ctx> { - ty: PointerType<'ctx>, - llvm_usize: IntType<'ctx>, -} - -impl<'ctx> NDArrayType<'ctx> { - /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. - pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { - let llvm_ndarray_ty = llvm_ty.get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}")); - }; - if llvm_ndarray_ty.count_fields() != 3 { - return Err(format!( - "Expected 3 fields in `NDArray`, got {}", - llvm_ndarray_ty.count_fields() - )); - } - - let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap(); - let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else { - return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}")); - }; - if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected {}-bit int type for `ndarray.0`, got {}-bit int", - llvm_usize.get_bit_width(), - ndarray_ndims_ty.get_bit_width() - )); - } - - let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap(); - let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else { - return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}")); - }; - let ndarray_dims = ndarray_pdims.get_element_type(); - let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else { - return Err(format!( - "Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}" - )); - }; - if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() { - return Err(format!( - "Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int", - llvm_usize.get_bit_width(), - ndarray_dims.get_bit_width() - )); - } - - let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap(); - let Ok(_) = PointerType::try_from(ndarray_data_ty) else { - return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}")); - }; - - Ok(()) - } - - /// Creates an instance of [`ListType`]. - #[must_use] - pub fn new( - generator: &G, - ctx: &'ctx Context, - dtype: BasicTypeEnum<'ctx>, - ) -> Self { - let llvm_usize = generator.get_size_type(ctx); - - // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } - // - // * num_dims: Number of dimensions in the array - // * dims: Pointer to an array containing the size of each dimension - // * data: Pointer to an array containing the array data - let llvm_ndarray = ctx - .struct_type( - &[ - llvm_usize.into(), - llvm_usize.ptr_type(AddressSpace::default()).into(), - dtype.ptr_type(AddressSpace::default()).into(), - ], - false, - ) - .ptr_type(AddressSpace::default()); - - NDArrayType::from_type(llvm_ndarray, llvm_usize) - } - - /// Creates an [`NDArrayType`] from a [`PointerType`]. - #[must_use] - pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self { - debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok()); - - NDArrayType { ty: ptr_ty, llvm_usize } - } - - /// Returns the type of the `size` field of this `ndarray` type. - #[must_use] - pub fn size_type(&self) -> IntType<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(0) - .map(BasicTypeEnum::into_int_type) - .unwrap() - } - - /// Returns the element type of this `ndarray` type. - #[must_use] - pub fn element_type(&self) -> AnyTypeEnum<'ctx> { - self.as_base_type() - .get_element_type() - .into_struct_type() - .get_field_type_at_index(2) - .map(BasicTypeEnum::into_pointer_type) - .map(PointerType::get_element_type) - .unwrap() - } -} - -impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> { - type Base = PointerType<'ctx>; - type Underlying = StructType<'ctx>; - type Value = NDArrayValue<'ctx>; - - fn new_value( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> Self::Value { - self.create_value( - generator.gen_var_alloc(ctx, self.as_underlying_type().into(), name).unwrap(), - name, - ) - } - - fn create_value( - &self, - value: >::Base, - name: Option<&'ctx str>, - ) -> Self::Value { - debug_assert_eq!(value.get_type(), self.as_base_type()); - - NDArrayValue { value, llvm_usize: self.llvm_usize, name } - } - - fn as_base_type(&self) -> Self::Base { - self.ty - } - - fn as_underlying_type(&self) -> Self::Underlying { - self.as_base_type().get_element_type().into_struct_type() - } -} - -impl<'ctx> From> for PointerType<'ctx> { - fn from(value: NDArrayType<'ctx>) -> Self { - value.as_base_type() - } -} - -/// Proxy type for accessing an `NDArray` value in LLVM. -#[derive(Copy, Clone)] -pub struct NDArrayValue<'ctx> { - value: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - name: Option<&'ctx str>, -} - -impl<'ctx> NDArrayValue<'ctx> { - /// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an - /// instance. - pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> { - NDArrayType::is_type(value.get_type(), llvm_usize) - } - - /// Creates an [`NDArrayValue`] from a [`PointerValue`]. - #[must_use] - pub fn from_ptr_val( - ptr: PointerValue<'ctx>, - llvm_usize: IntType<'ctx>, - name: Option<&'ctx str>, - ) -> Self { - debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok()); - - >::Type::from_type(ptr.get_type(), llvm_usize) - .create_value(ptr, name) - } - - /// Returns the pointer to the field storing the number of dimensions of this `NDArray`. - fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the number of dimensions `ndims` into this instance. - pub fn store_ndims( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ndims: IntValue<'ctx>, - ) { - debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx)); - - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_store(pndims, ndims).unwrap(); - } - - /// Returns the number of dimensions of this `NDArray` as a value. - pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - let pndims = self.ptr_to_ndims(ctx); - ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap() - } - - /// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr` - /// on the field. - fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the array of dimension sizes `dims` into this instance. - fn store_dim_sizes(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) { - ctx.builder.build_store(self.ptr_to_dims(ctx), dims).unwrap(); - } - - /// Convenience method for creating a new array storing dimension sizes with the given `size`. - pub fn create_dim_sizes( - &self, - ctx: &CodeGenContext<'ctx, '_>, - llvm_usize: IntType<'ctx>, - size: IntValue<'ctx>, - ) { - self.store_dim_sizes(ctx, ctx.builder.build_array_alloca(llvm_usize, size, "").unwrap()); - } - - /// Returns a proxy object to the field storing the size of each dimension of this `NDArray`. - #[must_use] - pub fn dim_sizes(&self) -> NDArrayDimsProxy<'ctx, '_> { - NDArrayDimsProxy(self) - } - - /// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr` - /// on the field. - pub fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let llvm_i32 = ctx.ctx.i32_type(); - let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.as_base_value(), - &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], - var_name.as_str(), - ) - .unwrap() - } - } - - /// Stores the array of data elements `data` into this instance. - fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) { - ctx.builder.build_store(self.ptr_to_data(ctx), data).unwrap(); - } - - /// Convenience method for creating a new array storing data elements with the given element - /// type `elem_ty` and `size`. - pub fn create_data( - &self, - ctx: &CodeGenContext<'ctx, '_>, - elem_ty: BasicTypeEnum<'ctx>, - size: IntValue<'ctx>, - ) { - self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, "").unwrap()); - } - - /// Returns a proxy object to the field storing the data of this `NDArray`. - #[must_use] - pub fn data(&self) -> NDArrayDataProxy<'ctx, '_> { - NDArrayDataProxy(self) - } -} - -impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> { - type Base = PointerValue<'ctx>; - type Underlying = StructValue<'ctx>; - type Type = NDArrayType<'ctx>; - - fn get_type(&self) -> Self::Type { - NDArrayType::from_type(self.as_base_value().get_type(), self.llvm_usize) - } - - fn as_base_value(&self) -> Self::Base { - self.value - } - - fn as_underlying_value( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - name: Option<&'ctx str>, - ) -> Self::Underlying { - ctx.builder - .build_load(self.as_base_value(), name.unwrap_or_default()) - .map(BasicValueEnum::into_struct_value) - .unwrap() - } -} - -impl<'ctx> From> for PointerValue<'ctx> { - fn from(value: NDArrayValue<'ctx>) -> Self { - value.as_base_value() - } -} - -/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM. -#[derive(Copy, Clone)] -pub struct NDArrayDimsProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); - -impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> { - fn element_type( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> AnyTypeEnum<'ctx> { - self.0.dim_sizes().base_ptr(ctx, generator).get_type().get_element_type() - } - - fn base_ptr( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.ptr_to_dims(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() - } - - fn size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> IntValue<'ctx> { - self.0.load_ndims(ctx) - } -} - -impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default(); - - unsafe { - ctx.builder - .build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str()) - .unwrap() - } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let size = self.size(ctx, generator); - let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap(); - ctx.make_assert( - generator, - in_range, - "0:IndexError", - "index {0} is out of bounds for axis 0 with size {1}", - [Some(*idx), Some(self.0.load_ndims(ctx)), None], - ctx.current_loc, - ); - - unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } - } -} - -impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} -impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> {} - -impl<'ctx> TypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { - fn downcast_to_type( - &self, - _: &mut CodeGenContext<'ctx, '_>, - value: BasicValueEnum<'ctx>, - ) -> IntValue<'ctx> { - value.into_int_value() - } -} - -impl<'ctx> TypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_> { - fn upcast_from_type( - &self, - _: &mut CodeGenContext<'ctx, '_>, - value: IntValue<'ctx>, - ) -> BasicValueEnum<'ctx> { - value.into() - } -} - -/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM. -#[derive(Copy, Clone)] -pub struct NDArrayDataProxy<'ctx, 'a>(&'a NDArrayValue<'ctx>); - -impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> { - fn element_type( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> AnyTypeEnum<'ctx> { - self.0.data().base_ptr(ctx, generator).get_type().get_element_type() - } - - fn base_ptr( - &self, - ctx: &CodeGenContext<'ctx, '_>, - _: &G, - ) -> PointerValue<'ctx> { - let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default(); - - ctx.builder - .build_load(self.0.ptr_to_data(ctx), var_name.as_str()) - .map(BasicValueEnum::into_pointer_value) - .unwrap() - } - - fn size( - &self, - ctx: &CodeGenContext<'ctx, '_>, - generator: &G, - ) -> IntValue<'ctx> { - call_ndarray_calc_size(generator, ctx, &self.as_slice_value(ctx, generator), (None, None)) - } -} - -impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> { - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[*idx], - name.unwrap_or_default(), - ) - .unwrap() - } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - idx: &IntValue<'ctx>, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let data_sz = self.size(ctx, generator); - let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, data_sz, "").unwrap(); - ctx.make_assert( - generator, - in_range, - "0:IndexError", - "index {0} is out of bounds with size {1}", - [Some(*idx), Some(self.0.load_ndims(ctx)), None], - ctx.current_loc, - ); - - unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) } - } -} - -impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {} -impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {} - -impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> - for NDArrayDataProxy<'ctx, '_> -{ - unsafe fn ptr_offset_unchecked( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - indices: &Index, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let indices_elem_ty = indices - .ptr_offset(ctx, generator, &llvm_usize.const_zero(), None) - .get_type() - .get_element_type(); - let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else { - panic!("Expected list[int32] but got {indices_elem_ty}") - }; - assert_eq!( - indices_elem_ty.get_bit_width(), - 32, - "Expected list[int32] but got list[int{}]", - indices_elem_ty.get_bit_width() - ); - - let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices); - - unsafe { - ctx.builder - .build_in_bounds_gep( - self.base_ptr(ctx, generator), - &[index], - name.unwrap_or_default(), - ) - .unwrap() - } - } - - fn ptr_offset( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut G, - indices: &Index, - name: Option<&str>, - ) -> PointerValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let indices_size = indices.size(ctx, generator); - let nidx_leq_ndims = ctx - .builder - .build_int_compare(IntPredicate::SLE, indices_size, self.0.load_ndims(ctx), "") - .unwrap(); - ctx.make_assert( - generator, - nidx_leq_ndims, - "0:IndexError", - "invalid index to scalar variable", - [None, None, None], - ctx.current_loc, - ); - - let indices_len = indices.size(ctx, generator); - let ndarray_len = self.0.load_ndims(ctx); - let len = call_int_umin(ctx, indices_len, ndarray_len, None); - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (len, false), - |generator, ctx, _, i| { - let (dim_idx, dim_sz) = unsafe { - ( - indices.get_unchecked(ctx, generator, &i, None).into_int_value(), - self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None), - ) - }; - let dim_idx = ctx - .builder - .build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "") - .unwrap(); - - let dim_lt = - ctx.builder.build_int_compare(IntPredicate::SLT, dim_idx, dim_sz, "").unwrap(); - - ctx.make_assert( - generator, - dim_lt, - "0:IndexError", - "index {0} is out of bounds for axis 0 with size {1}", - [Some(dim_idx), Some(dim_sz), None], - ctx.current_loc, - ); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) } - } -} - -impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index> - for NDArrayDataProxy<'ctx, '_> -{ -} -impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index> - for NDArrayDataProxy<'ctx, '_> -{ -} diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 22c2b53d..68a64b53 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -3,7 +3,7 @@ use inkwell::{ context::Context, memory_buffer::MemoryBuffer, module::Module, - types::{BasicTypeEnum, IntType}, + types::BasicTypeEnum, values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, IntValue}, AddressSpace, IntPredicate, }; @@ -12,18 +12,13 @@ use itertools::Either; use nac3parser::ast::Expr; use super::{ - classes::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, - TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, - }, - llvm_intrinsics, + classes::{ArrayLikeValue, ListValue}, macros::codegen_unreachable, model::{function::FnCall, *}, object::{ list::List, ndarray::{broadcast::ShapeEntry, indexing::NDIndex, nditer::NDIter, NDArray}, }, - stmt::gen_for_callback_incrementing, CodeGenContext, CodeGenerator, }; use crate::{symbol_resolver::SymbolResolver, typecheck::typedef::Type}; @@ -589,373 +584,6 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo .unwrap() } -/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the -/// calculated total size. -/// -/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. -/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, -/// or [`None`] if starting from the first dimension and ending at the last dimension -/// respectively. -pub fn call_ndarray_calc_size<'ctx, G, Dims>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - dims: &Dims, - (begin, end): (Option>, Option>), -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Dims: ArrayLikeIndexer<'ctx>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_size", - 64 => "__nac3_ndarray_calc_size64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_size_fn_t = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], - false, - ); - let ndarray_calc_size_fn = - ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { - ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) - }); - - let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); - let end = end.unwrap_or_else(|| dims.size(ctx, generator)); - ctx.builder - .build_call( - ndarray_calc_size_fn, - &[ - dims.base_ptr(ctx, generator).into(), - dims.size(ctx, generator).into(), - begin.into(), - end.into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap() -} - -/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] -/// containing `i32` indices of the flattened index. -/// -/// * `index` - The index to compute the multidimensional index for. -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &mut CodeGenContext<'ctx, '_>, - index: IntValue<'ctx>, - ndarray: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_void = ctx.ctx.void_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_nd_indices", - 64 => "__nac3_ndarray_calc_nd_indices64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_nd_indices_fn = - ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); - - let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); - - ctx.builder - .build_call( - ndarray_calc_nd_indices_fn, - &[ - index.into(), - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} - -fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Indices, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Indices: ArrayLikeIndexer<'ctx>, -{ - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - debug_assert_eq!( - IntType::try_from(indices.element_type(ctx, generator)) - .map(IntType::get_bit_width) - .unwrap_or_default(), - llvm_i32.get_bit_width(), - "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" - ); - debug_assert_eq!( - indices.size(ctx, generator).get_type().get_bit_width(), - llvm_usize.get_bit_width(), - "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" - ); - - let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_flatten_index", - 64 => "__nac3_ndarray_flatten_index64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_flatten_index_fn = - ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], - false, - ); - - ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) - }); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - let ndarray_dims = ndarray.dim_sizes(); - - let index = ctx - .builder - .build_call( - ndarray_flatten_index_fn, - &[ - ndarray_dims.base_ptr(ctx, generator).into(), - ndarray_num_dims.into(), - indices.base_ptr(ctx, generator).into(), - indices.size(ctx, generator).into(), - ], - "", - ) - .map(CallSiteValue::try_as_basic_value) - .map(|v| v.map_left(BasicValueEnum::into_int_value)) - .map(Either::unwrap_left) - .unwrap(); - - index -} - -/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the -/// multidimensional index. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `indices` - The multidimensional index to compute the flattened index for. -pub fn call_ndarray_flatten_index<'ctx, G, Index>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - indices: &Index, -) -> IntValue<'ctx> -where - G: CodeGenerator + ?Sized, - Index: ArrayLikeIndexer<'ctx>, -{ - call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of -/// dimension and size of each dimension of the resultant `ndarray`. -pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast", - 64 => "__nac3_ndarray_calc_broadcast64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[ - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - llvm_usize.into(), - llvm_pusize.into(), - ], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (min_ndims, false), - |generator, ctx, _, idx| { - let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); - let (lhs_dim_sz, rhs_dim_sz) = unsafe { - ( - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), - ) - }; - - let llvm_usize_const_one = llvm_usize.const_int(1, false); - let lhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let rhs_eqz = ctx - .builder - .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") - .unwrap(); - let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); - - let lhs_eq_rhs = ctx - .builder - .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") - .unwrap(); - - let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); - - ctx.make_assert( - generator, - is_compatible, - "0:ValueError", - "operands could not be broadcast together", - [None, None, None], - ctx.current_loc, - ); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); - let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); - let rhs_ndims = rhs.load_ndims(ctx); - let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); - let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[ - lhs_dims.into(), - lhs_ndims.into(), - rhs_dims.into(), - rhs_ndims.into(), - out_dims.base_ptr(ctx, generator).into(), - ], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - out_dims, - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} - -/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] -/// containing the indices used for accessing `array` corresponding to the index of the broadcasted -/// array `broadcast_idx`. -pub fn call_ndarray_calc_broadcast_index< - 'ctx, - G: CodeGenerator + ?Sized, - BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, ->( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - array: NDArrayValue<'ctx>, - broadcast_idx: &BroadcastIdx, -) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_calc_broadcast_idx", - 64 => "__nac3_ndarray_calc_broadcast_idx64", - bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw), - }; - let ndarray_calc_broadcast_fn = - ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { - let fn_type = llvm_usize.fn_type( - &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], - false, - ); - - ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) - }); - - let broadcast_size = broadcast_idx.size(ctx, generator); - let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); - - let array_dims = array.dim_sizes().base_ptr(ctx, generator); - let array_ndims = array.load_ndims(ctx); - let broadcast_idx_ptr = unsafe { - broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder - .build_call( - ndarray_calc_broadcast_fn, - &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], - "", - ) - .unwrap(); - - TypedArrayLikeAdapter::from( - ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), - Box::new(|_, v| v.into_int_value()), - Box::new(|_, v| v.into()), - ) -} - // When [`TypeContext::size_type`] is 32-bits, the function name is "{fn_name}". // When [`TypeContext::size_type`] is 64-bits, the function name is "{fn_name}64". #[must_use] diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index e9f5dbd6..cbfdf51f 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,1441 +1,26 @@ use inkwell::{ - types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType}, - values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, - AddressSpace, IntPredicate, + values::{BasicValue, BasicValueEnum, PointerValue}, + IntPredicate, }; use nac3parser::ast::StrRef; use super::{ - classes::{ - ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayType, NDArrayValue, - ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, - TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, - }, - irrt::{ - calculate_len_for_slice_range, call_ndarray_calc_broadcast, - call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size, - }, - llvm_intrinsics::{self, call_memcpy_generic}, macros::codegen_unreachable, model::*, object::{ any::AnyObject, ndarray::{nditer::NDIterHandle, shape_util::parse_numpy_int_sequence, NDArrayObject}, }, - stmt::{ - gen_for_callback, gen_for_callback_incrementing, gen_for_range_callback, - gen_if_else_expr_callback, - }, + stmt::gen_for_callback, CodeGenContext, CodeGenerator, }; use crate::{ symbol_resolver::ValueEnum, - toplevel::{ - helper::extract_ndims, - numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, - DefinitionId, - }, + toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys, DefinitionId}, typecheck::typedef::{FunSignature, Type}, }; -/// Creates an uninitialized `NDArray` instance. -fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> Result, String> { - let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); - - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_ndarray_t = ctx - .get_llvm_type(generator, ndarray_ty) - .into_pointer_type() - .get_element_type() - .into_struct_type(); - - let ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?; - - Ok(NDArrayValue::from_ptr_val(ndarray, llvm_usize, None)) -} - -/// Creates an `NDArray` instance from a dynamic shape. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The shape of the `NDArray`. -/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. -/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. -fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - shape: &V, - shape_len_fn: LenFn, - shape_data_fn: DataFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - LenFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>, &V) -> Result, String>, - DataFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - &V, - IntValue<'ctx>, - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - - // Assert that all dimensions are non-negative - let shape_len = shape_len_fn(generator, ctx, shape)?; - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_len, false), - |generator, ctx, _, i| { - let shape_dim = shape_data_fn(generator, ctx, shape, i)?; - debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - - let shape_dim_gez = ctx - .builder - .build_int_compare( - IntPredicate::SGE, - shape_dim, - shape_dim.get_type().const_zero(), - "", - ) - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - // TODO: Disallow dim_sz > u32_MAX - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; - - let num_dims = shape_len_fn(generator, ctx, shape)?; - ndarray.store_ndims(ctx, generator, num_dims); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); - - // Copy the dimension sizes from shape to ndarray.dims - let shape_len = shape_len_fn(generator, ctx, shape)?; - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (shape_len, false), - |generator, ctx, _, i| { - let shape_dim = shape_data_fn(generator, ctx, shape, i)?; - debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - - let ndarray_pdim = - unsafe { ndarray.dim_sizes().ptr_offset_unchecked(ctx, generator, &i, None) }; - - ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); - - Ok(ndarray) -} - -/// Creates an `NDArray` instance from a constant shape. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The shape of the `NDArray`, represented am array of [`IntValue`]s. -pub fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: &[IntValue<'ctx>], -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - for &shape_dim in shape { - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - let shape_dim_gez = ctx - .builder - .build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "") - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - // TODO: Disallow dim_sz > u32_MAX - } - - let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; - - let num_dims = llvm_usize.const_int(shape.len() as u64, false); - ndarray.store_ndims(ctx, generator, num_dims); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims); - - for (i, &shape_dim) in shape.iter().enumerate() { - let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap(); - let ndarray_dim = unsafe { - ndarray.dim_sizes().ptr_offset_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, true), - None, - ) - }; - - ctx.builder.build_store(ndarray_dim, shape_dim).unwrap(); - } - - let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray); - - Ok(ndarray) -} - -/// Initializes the `data` field of [`NDArrayValue`] based on the `ndims` and `dim_sz` fields. -fn ndarray_init_data<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - ndarray: NDArrayValue<'ctx>, -) -> NDArrayValue<'ctx> { - let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum(); - assert!(llvm_ndarray_data_t.is_sized()); - - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), - (None, None), - ); - ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); - - ndarray -} - -fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i32_type().const_zero().into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - ctx.ctx.i64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_zero().into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "").into() - } else { - codegen_unreachable!(ctx) - } -} - -fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, -) -> BasicValueEnum<'ctx> { - if [ctx.primitives.int32, ctx.primitives.uint32] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32); - ctx.ctx.i32_type().const_int(1, is_signed).into() - } else if [ctx.primitives.int64, ctx.primitives.uint64] - .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) - { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64); - ctx.ctx.i64_type().const_int(1, is_signed).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { - ctx.ctx.f64_type().const_float(1.0).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { - ctx.ctx.bool_type().const_int(1, false).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { - ctx.gen_string(generator, "1").into() - } else { - codegen_unreachable!(ctx) - } -} - -/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -/// -/// ### Notes on `shape` -/// -/// Just like numpy, the `shape` argument can be: -/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` -/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` -/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` -/// -/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to -/// learn how `shape` gets from being a Python user expression to here. -fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - match shape { - BasicValueEnum::PointerValue(shape_list_ptr) - if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() => - { - // 1. A list of ints; e.g., `np.empty([600, 800, 3])` - - let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None); - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &shape_list, - |_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)), - |generator, ctx, shape_list, idx| { - Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value()) - }, - ) - } - BasicValueEnum::StructValue(shape_tuple) => { - // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` - // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. - - // Get the length/size of the tuple, which also happens to be the value of `ndims`. - let ndims = shape_tuple.get_type().count_fields(); - - let mut shape = Vec::with_capacity(ndims as usize); - for dim_i in 0..ndims { - let dim = ctx - .builder - .build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str()) - .unwrap() - .into_int_value(); - - shape.push(dim); - } - create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice()) - } - BasicValueEnum::IntValue(shape_int) => { - // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` - - create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int]) - } - _ => codegen_unreachable!(ctx), - } -} - -/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as -/// its input. -fn ndarray_fill_flattened<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: NDArrayValue<'ctx>, - value_fn: ValueFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - IntValue<'ctx>, - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - &ndarray.dim_sizes().as_slice_value(ctx, generator), - (None, None), - ); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_zero(), - (ndarray_num_elems, false), - |generator, ctx, _, i| { - let elem = unsafe { ndarray.data().ptr_offset_unchecked(ctx, generator, &i, None) }; - - let value = value_fn(generator, ctx, i)?; - ctx.builder.build_store(elem, value).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) -} - -/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices -/// as its input. -fn ndarray_fill_indexed<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - ndarray: NDArrayValue<'ctx>, - value_fn: ValueFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - &TypedArrayLikeAdapter<'ctx, IntValue<'ctx>>, - ) -> Result, String>, -{ - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, idx| { - let indices = call_ndarray_calc_nd_indices(generator, ctx, idx, ndarray); - - value_fn(generator, ctx, &indices) - }) -} - -fn ndarray_fill_mapping<'ctx, 'a, G, MapFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - src: NDArrayValue<'ctx>, - dest: NDArrayValue<'ctx>, - map_fn: MapFn, -) -> Result<(), String> -where - G: CodeGenerator + ?Sized, - MapFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - BasicValueEnum<'ctx>, - ) -> Result, String>, -{ - ndarray_fill_flattened(generator, ctx, dest, |generator, ctx, i| { - let elem = unsafe { src.data().get_unchecked(ctx, generator, &i, None) }; - - map_fn(generator, ctx, elem) - }) -} - -/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of -/// the target `ndarray`. -fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - target: NDArrayValue<'ctx>, - source: NDArrayValue<'ctx>, -) { - let array_ndims = source.load_ndims(ctx); - let broadcast_size = target.load_ndims(ctx); - - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(), - "0:ValueError", - "operands cannot be broadcast together", - [None, None, None], - ctx.current_loc, - ); -} - -/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value -/// with broadcast-compatible shapes. -fn ndarray_broadcast_fill<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - res: NDArrayValue<'ctx>, - lhs: (BasicValueEnum<'ctx>, bool), - rhs: (BasicValueEnum<'ctx>, bool), - value_fn: ValueFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (lhs_val, lhs_scalar) = lhs; - let (rhs_val, rhs_scalar) = rhs; - - assert!( - !(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type() - ); - - // Assert that all ndarray operands are broadcastable to the target size - if !lhs_scalar { - let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); - ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val); - } - - if !rhs_scalar { - let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); - ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); - } - - ndarray_fill_indexed(generator, ctx, res, |generator, ctx, idx| { - let lhs_elem = if lhs_scalar { - lhs_val - } else { - let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); - let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx); - - unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) } - }; - - let rhs_elem = if rhs_scalar { - rhs_val - } else { - let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); - let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx); - - unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) } - }; - - value_fn(generator, ctx, (lhs_elem, rhs_elem)) - })?; - - Ok(res) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let supported_types = [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ctx.primitives.float, - ctx.primitives.bool, - ctx.primitives.str, - ]; - assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); - - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = ndarray_zero_value(generator, ctx, elem_ty); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, -) -> Result, String> { - let supported_types = [ - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ctx.primitives.float, - ctx.primitives.bool, - ctx.primitives.str, - ]; - assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); - - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = ndarray_one_value(generator, ctx, elem_ty); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.full`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `shape` - The `shape` parameter used to construct the `NDArray`. -fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - shape: BasicValueEnum<'ctx>, - fill_value: BasicValueEnum<'ctx>, -) -> Result, String> { - let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; - ndarray_fill_flattened(generator, ctx, ndarray, |generator, ctx, _| { - let value = if fill_value.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - fill_value.into_pointer_value(), - fill_value.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if fill_value.is_int_value() || fill_value.is_float_value() { - fill_value - } else { - codegen_unreachable!(ctx) - }; - - Ok(value) - })?; - - Ok(ndarray) -} - -/// Returns the number of dimensions for a multidimensional list as an [`IntValue`]. -fn llvm_ndlist_get_ndims<'ctx, G: CodeGenerator + ?Sized>( - generator: &G, - ctx: &CodeGenContext<'ctx, '_>, - ty: PointerType<'ctx>, -) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let list_ty = ListType::from_type(ty, llvm_usize); - let list_elem_ty = list_ty.element_type(); - - let ndims = llvm_usize.const_int(1, false); - match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => { - ndims.const_add(llvm_ndlist_get_ndims(generator, ctx, ptr_ty)) - } - - AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => { - todo!("Getting ndims for list[ndarray] not supported") - } - - _ => ndims, - } -} - -/// Returns the number of dimensions for an array-like object as an [`IntValue`]. -fn llvm_arraylike_get_ndims<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - value: BasicValueEnum<'ctx>, -) -> IntValue<'ctx> { - let llvm_usize = generator.get_size_type(ctx.ctx); - - match value { - BasicValueEnum::PointerValue(v) if NDArrayValue::is_instance(v, llvm_usize).is_ok() => { - NDArrayValue::from_ptr_val(v, llvm_usize, None).load_ndims(ctx) - } - - BasicValueEnum::PointerValue(v) if ListValue::is_instance(v, llvm_usize).is_ok() => { - llvm_ndlist_get_ndims(generator, ctx, v.get_type()) - } - - _ => llvm_usize.const_zero(), - } -} - -/// Flattens and copies the values from a multidimensional list into an [`NDArrayValue`]. -fn ndarray_from_ndlist_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - src_lst: ListValue<'ctx>, - dim: u64, -) -> Result<(), String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let list_elem_ty = src_lst.get_type().element_type(); - - match list_elem_ty { - AnyTypeEnum::PointerType(ptr_ty) if ListType::is_type(ptr_ty, llvm_usize).is_ok() => { - // The stride of elements in this dimension, i.e. the number of elements between arr[i] - // and arr[i + 1] in this dimension - let stride = call_ndarray_calc_size( - generator, - ctx, - &dst_arr.dim_sizes(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - - gen_for_range_callback( - generator, - ctx, - None, - true, - |_, _| Ok(llvm_usize.const_zero()), - (|_, ctx| Ok(src_lst.load_size(ctx, None)), false), - |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _, i| { - let offset = ctx.builder.build_int_mul(stride, i, "").unwrap(); - - let dst_ptr = - unsafe { ctx.builder.build_gep(dst_slice_ptr, &[offset], "").unwrap() }; - - let nested_lst_elem = ListValue::from_ptr_val( - unsafe { src_lst.data().get_unchecked(ctx, generator, &i, None) } - .into_pointer_value(), - llvm_usize, - None, - ); - - ndarray_from_ndlist_impl( - generator, - ctx, - elem_ty, - (dst_arr, dst_ptr), - nested_lst_elem, - dim + 1, - )?; - - Ok(()) - }, - )?; - } - - AnyTypeEnum::PointerType(ptr_ty) if NDArrayType::is_type(ptr_ty, llvm_usize).is_ok() => { - todo!("Not implemented for list[ndarray]") - } - - _ => { - let lst_len = src_lst.load_size(ctx, None); - let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); - let sizeof_elem = ctx.builder.build_int_cast(sizeof_elem, llvm_usize, "").unwrap(); - - let cpy_len = ctx - .builder - .build_int_mul( - ctx.builder.build_int_z_extend_or_bit_cast(lst_len, llvm_usize, "").unwrap(), - sizeof_elem, - "", - ) - .unwrap(); - - call_memcpy_generic( - ctx, - dst_slice_ptr, - src_lst.data().base_ptr(ctx, generator), - cpy_len, - llvm_i1.const_zero(), - ); - } - } - - Ok(()) -} - -/// LLVM-typed implementation for `ndarray.array`. -fn call_ndarray_array_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - object: BasicValueEnum<'ctx>, - copy: IntValue<'ctx>, - ndmin: IntValue<'ctx>, -) -> Result, String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndmin = ctx.builder.build_int_z_extend_or_bit_cast(ndmin, llvm_usize, "").unwrap(); - - // TODO(Derppening): Add assertions for sizes of different dimensions - - // object is not a pointer - 0-dim NDArray - if !object.is_pointer_value() { - let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[])?; - - unsafe { - ndarray.data().set_unchecked(ctx, generator, &llvm_usize.const_zero(), object); - } - - return Ok(ndarray); - } - - let object = object.into_pointer_value(); - - // object is an NDArray instance - copy object unless copy=0 && ndmin < object.ndims - if NDArrayValue::is_instance(object, llvm_usize).is_ok() { - let object = NDArrayValue::from_ptr_val(object, llvm_usize, None); - - let ndarray = gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - let copy_nez = ctx - .builder - .build_int_compare(IntPredicate::NE, copy, llvm_i1.const_zero(), "") - .unwrap(); - let ndmin_gt_ndims = ctx - .builder - .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") - .unwrap(); - - Ok(ctx.builder.build_and(copy_nez, ndmin_gt_ndims, "").unwrap()) - }, - |generator, ctx| { - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &object, - |_, ctx, object| { - let ndims = object.load_ndims(ctx); - let ndmin_gt_ndims = ctx - .builder - .build_int_compare(IntPredicate::UGT, ndmin, object.load_ndims(ctx), "") - .unwrap(); - - Ok(ctx - .builder - .build_select(ndmin_gt_ndims, ndmin, ndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - |generator, ctx, object, idx| { - let ndims = object.load_ndims(ctx); - let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); - // The number of dimensions to prepend 1's to - let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); - - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::UGE, idx, offset, "") - .unwrap()) - }, - |_, _| Ok(Some(llvm_usize.const_int(1, false))), - |_, ctx| Ok(Some(ctx.builder.build_int_sub(idx, offset, "").unwrap())), - )? - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - )?; - - ndarray_sliced_copyto_impl( - generator, - ctx, - elem_ty, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - (object, object.data().base_ptr(ctx, generator)), - 0, - &[], - )?; - - Ok(Some(ndarray.as_base_value())) - }, - |_, _| Ok(Some(object.as_base_value())), - )?; - - return Ok(NDArrayValue::from_ptr_val( - ndarray.map(BasicValueEnum::into_pointer_value).unwrap(), - llvm_usize, - None, - )); - } - - // Remaining case: TList - assert!(ListValue::is_instance(object, llvm_usize).is_ok()); - let object = ListValue::from_ptr_val(object, llvm_usize, None); - - // The number of dimensions to prepend 1's to - let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin = llvm_intrinsics::call_int_umax(ctx, ndims, ndmin, None); - let offset = ctx.builder.build_int_sub(ndmin, ndims, "").unwrap(); - - let ndarray = create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &object, - |generator, ctx, object| { - let ndims = llvm_ndlist_get_ndims(generator, ctx, object.as_base_value().get_type()); - let ndmin_gt_ndims = - ctx.builder.build_int_compare(IntPredicate::UGT, ndmin, ndims, "").unwrap(); - - Ok(ctx - .builder - .build_select(ndmin_gt_ndims, ndmin, ndims, "") - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - |generator, ctx, object, idx| { - Ok(gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx.builder.build_int_compare(IntPredicate::ULT, idx, offset, "").unwrap()) - }, - |_, _| Ok(Some(llvm_usize.const_int(1, false))), - |generator, ctx| { - let make_llvm_list = |elem_ty: BasicTypeEnum<'ctx>| { - ctx.ctx.struct_type( - &[elem_ty.ptr_type(AddressSpace::default()).into(), llvm_usize.into()], - false, - ) - }; - - let llvm_i8 = ctx.ctx.i8_type(); - let llvm_list_i8 = make_llvm_list(llvm_i8.into()); - let llvm_plist_i8 = llvm_list_i8.ptr_type(AddressSpace::default()); - - // Cast list to { i8*, usize } since we only care about the size - let lst = generator - .gen_var_alloc( - ctx, - ListType::new(generator, ctx.ctx, llvm_i8.into()).as_base_type().into(), - None, - ) - .unwrap(); - ctx.builder - .build_store( - lst, - ctx.builder - .build_bit_cast(object.as_base_value(), llvm_plist_i8, "") - .unwrap(), - ) - .unwrap(); - - let stop = ctx.builder.build_int_sub(idx, offset, "").unwrap(); - gen_for_range_callback( - generator, - ctx, - None, - true, - |_, _| Ok(llvm_usize.const_zero()), - (|_, _| Ok(stop), false), - |_, _| Ok(llvm_usize.const_int(1, false)), - |generator, ctx, _, _| { - let plist_plist_i8 = make_llvm_list(llvm_plist_i8.into()) - .ptr_type(AddressSpace::default()); - - let this_dim = ctx - .builder - .build_load(lst, "") - .map(BasicValueEnum::into_pointer_value) - .map(|v| ctx.builder.build_bit_cast(v, plist_plist_i8, "").unwrap()) - .map(BasicValueEnum::into_pointer_value) - .unwrap(); - let this_dim = ListValue::from_ptr_val(this_dim, llvm_usize, None); - - // TODO: Assert this_dim.sz != 0 - - let next_dim = unsafe { - this_dim.data().get_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - } - .into_pointer_value(); - ctx.builder - .build_store( - lst, - ctx.builder - .build_bit_cast(next_dim, llvm_plist_i8, "") - .unwrap(), - ) - .unwrap(); - - Ok(()) - }, - )?; - - let lst = ListValue::from_ptr_val( - ctx.builder - .build_load(lst, "") - .map(BasicValueEnum::into_pointer_value) - .unwrap(), - llvm_usize, - None, - ); - - Ok(Some(lst.load_size(ctx, None))) - }, - )? - .map(BasicValueEnum::into_int_value) - .unwrap()) - }, - )?; - - ndarray_from_ndlist_impl( - generator, - ctx, - elem_ty, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - object, - 0, - )?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.eye`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - nrows: IntValue<'ctx>, - ncols: IntValue<'ctx>, - offset: IntValue<'ctx>, -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "").unwrap(); - let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "").unwrap(); - - let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, &[nrows, ncols])?; - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, indices| { - let (row, col) = unsafe { - ( - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None), - indices.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None), - ) - }; - - let col_with_offset = ctx - .builder - .build_int_add( - col, - ctx.builder.build_int_s_extend_or_bit_cast(offset, llvm_i32, "").unwrap(), - "", - ) - .unwrap(); - let is_on_diag = - ctx.builder.build_int_compare(IntPredicate::EQ, row, col_with_offset, "").unwrap(); - - let zero = ndarray_zero_value(generator, ctx, elem_ty); - let one = ndarray_one_value(generator, ctx, elem_ty); - - let value = ctx.builder.build_select(is_on_diag, one, zero, "").unwrap(); - - Ok(value) - })?; - - Ok(ndarray) -} - -/// Copies a slice of an [`NDArrayValue`] to another. -/// -/// - `dst_arr`: The [`NDArrayValue`] instance of the destination array. The `ndims` and `dim_sz` -/// fields should be populated before calling this function. -/// - `dst_slice_ptr`: The [`PointerValue`] to the first element of the currently processing -/// dimensional slice in the destination array. -/// - `src_arr`: The [`NDArrayValue`] instance of the source array. -/// - `src_slice_ptr`: The [`PointerValue`] to the first element of the currently processing -/// dimensional slice in the source array. -/// - `dim`: The index of the currently processing dimension. -/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to -/// this dimension. The `start`/`stop` values of each slice must be non-negative indices. -fn ndarray_sliced_copyto_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - (dst_arr, dst_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - (src_arr, src_slice_ptr): (NDArrayValue<'ctx>, PointerValue<'ctx>), - dim: u64, - slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], -) -> Result<(), String> { - let llvm_i1 = ctx.ctx.bool_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - // If there are no (remaining) slice expressions, memcpy the entire dimension - if slices.is_empty() { - let sizeof_elem = ctx.get_llvm_type(generator, elem_ty).size_of().unwrap(); - - let stride = call_ndarray_calc_size( - generator, - ctx, - &src_arr.dim_sizes(), - (Some(llvm_usize.const_int(dim, false)), None), - ); - let stride = - ctx.builder.build_int_z_extend_or_bit_cast(stride, sizeof_elem.get_type(), "").unwrap(); - - let cpy_len = ctx.builder.build_int_mul(stride, sizeof_elem, "").unwrap(); - - call_memcpy_generic(ctx, dst_slice_ptr, src_slice_ptr, cpy_len, llvm_i1.const_zero()); - - return Ok(()); - } - - // The stride of elements in this dimension, i.e. the number of elements between arr[i] and - // arr[i + 1] in this dimension - let src_stride = call_ndarray_calc_size( - generator, - ctx, - &src_arr.dim_sizes(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - let dst_stride = call_ndarray_calc_size( - generator, - ctx, - &dst_arr.dim_sizes(), - (Some(llvm_usize.const_int(dim + 1, false)), None), - ); - - let (start, stop, step) = slices[0]; - let start = ctx.builder.build_int_s_extend_or_bit_cast(start, llvm_usize, "").unwrap(); - let stop = ctx.builder.build_int_s_extend_or_bit_cast(stop, llvm_usize, "").unwrap(); - let step = ctx.builder.build_int_s_extend_or_bit_cast(step, llvm_usize, "").unwrap(); - - let dst_i_addr = generator.gen_var_alloc(ctx, start.get_type().into(), None).unwrap(); - ctx.builder.build_store(dst_i_addr, start.get_type().const_zero()).unwrap(); - - gen_for_range_callback( - generator, - ctx, - None, - false, - |_, _| Ok(start), - (|_, _| Ok(stop), true), - |_, _| Ok(step), - |generator, ctx, _, src_i| { - // Calculate the offset of the active slice - let src_data_offset = ctx.builder.build_int_mul(src_stride, src_i, "").unwrap(); - let dst_i = - ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let dst_data_offset = ctx.builder.build_int_mul(dst_stride, dst_i, "").unwrap(); - - let (src_ptr, dst_ptr) = unsafe { - ( - ctx.builder.build_gep(src_slice_ptr, &[src_data_offset], "").unwrap(), - ctx.builder.build_gep(dst_slice_ptr, &[dst_data_offset], "").unwrap(), - ) - }; - - ndarray_sliced_copyto_impl( - generator, - ctx, - elem_ty, - (dst_arr, dst_ptr), - (src_arr, src_ptr), - dim + 1, - &slices[1..], - )?; - - let dst_i = - ctx.builder.build_load(dst_i_addr, "").map(BasicValueEnum::into_int_value).unwrap(); - let dst_i_add1 = - ctx.builder.build_int_add(dst_i, llvm_usize.const_int(1, false), "").unwrap(); - ctx.builder.build_store(dst_i_addr, dst_i_add1).unwrap(); - - Ok(()) - }, - )?; - - Ok(()) -} - -/// Copies a [`NDArrayValue`] using slices. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// - `slices`: List of all slices, with the first element corresponding to the slice applicable to -/// this dimension. The `start`/`stop` values of each slice must be positive indices. -pub fn ndarray_sliced_copy<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - this: NDArrayValue<'ctx>, - slices: &[(IntValue<'ctx>, IntValue<'ctx>, IntValue<'ctx>)], -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let ndarray = if slices.is_empty() { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &this, - |_, ctx, shape| Ok(shape.load_ndims(ctx)), - |generator, ctx, shape, idx| unsafe { - Ok(shape.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) - }, - )? - } else { - let ndarray = create_ndarray_uninitialized(generator, ctx, elem_ty)?; - ndarray.store_ndims(ctx, generator, this.load_ndims(ctx)); - - let ndims = this.load_ndims(ctx); - ndarray.create_dim_sizes(ctx, llvm_usize, ndims); - - // Populate the first slices.len() dimensions by computing the size of each dim slice - for (i, (start, stop, step)) in slices.iter().enumerate() { - // HACK: workaround calculate_len_for_slice_range requiring exclusive stop - let stop = ctx - .builder - .build_select( - ctx.builder - .build_int_compare( - IntPredicate::SLT, - *step, - llvm_i32.const_zero(), - "is_neg", - ) - .unwrap(), - ctx.builder - .build_int_sub(*stop, llvm_i32.const_int(1, true), "e_min_one") - .unwrap(), - ctx.builder - .build_int_add(*stop, llvm_i32.const_int(1, true), "e_add_one") - .unwrap(), - "final_e", - ) - .map(BasicValueEnum::into_int_value) - .unwrap(); - - let slice_len = calculate_len_for_slice_range(generator, ctx, *start, stop, *step); - let slice_len = - ctx.builder.build_int_z_extend_or_bit_cast(slice_len, llvm_usize, "").unwrap(); - - unsafe { - ndarray.dim_sizes().set_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(i as u64, false), - slice_len, - ); - } - } - - // Populate the rest by directly copying the dim size from the source array - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_usize.const_int(slices.len() as u64, false), - (this.load_ndims(ctx), false), - |generator, ctx, _, idx| { - unsafe { - let dim_sz = this.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None); - ndarray.dim_sizes().set_typed_unchecked(ctx, generator, &idx, dim_sz); - } - - Ok(()) - }, - llvm_usize.const_int(1, false), - ) - .unwrap(); - - ndarray_init_data(generator, ctx, elem_ty, ndarray) - }; - - ndarray_sliced_copyto_impl( - generator, - ctx, - elem_ty, - (ndarray, ndarray.data().base_ptr(ctx, generator)), - (this, this.data().base_ptr(ctx, generator)), - 0, - slices, - )?; - - Ok(ndarray) -} - -/// LLVM-typed implementation for generating the implementation for `ndarray.copy`. -/// -/// * `elem_ty` - The element type of the `NDArray`. -fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - this: NDArrayValue<'ctx>, -) -> Result, String> { - ndarray_sliced_copy(generator, ctx, elem_ty, this, &[]) -} - -pub fn ndarray_elementwise_unaryop_impl<'ctx, 'a, G, MapFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - res: Option>, - operand: NDArrayValue<'ctx>, - map_fn: MapFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - MapFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - BasicValueEnum<'ctx>, - ) -> Result, String>, -{ - let res = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &operand, - |_, ctx, v| Ok(v.load_ndims(ctx)), - |generator, ctx, v, idx| unsafe { - Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - }); - - ndarray_fill_mapping(generator, ctx, operand, res, |generator, ctx, elem| { - map_fn(generator, ctx, elem) - })?; - - Ok(res) -} - -/// LLVM-typed implementation for computing elementwise binary operations on two input operands. -/// -/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output -/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple. -/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the -/// `value_fn` arguments tuple for all output elements. -/// -/// The second element of the tuple indicates whether to treat the operand value as a `ndarray` -/// (which would be accessed by its broadcast index) or as a scalar value (which would be -/// broadcast to all elements). -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -/// * `value_fn` - Function mapping the two input elements into the result. -/// -/// # Panic -/// -/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`. -pub fn ndarray_elementwise_binop_impl<'ctx, 'a, G, ValueFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - elem_ty: Type, - res: Option>, - lhs: (BasicValueEnum<'ctx>, bool), - rhs: (BasicValueEnum<'ctx>, bool), - value_fn: ValueFn, -) -> Result, String> -where - G: CodeGenerator + ?Sized, - ValueFn: Fn( - &mut G, - &mut CodeGenContext<'ctx, 'a>, - (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>), - ) -> Result, String>, -{ - let llvm_usize = generator.get_size_type(ctx.ctx); - - let (lhs_val, lhs_scalar) = lhs; - let (rhs_val, rhs_scalar) = rhs; - - assert!( - !(lhs_scalar && rhs_scalar), - "One of the operands must be a ndarray instance: `{}`, `{}`", - lhs_val.get_type(), - rhs_val.get_type() - ); - - let ndarray = res.unwrap_or_else(|| { - if lhs_scalar && rhs_scalar { - let lhs_val = - NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); - let rhs_val = - NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); - - let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); - - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &ndarray_dims, - |generator, ctx, v| Ok(v.size(ctx, generator)), - |generator, ctx, v, idx| unsafe { - Ok(v.get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - } else { - let ndarray = NDArrayValue::from_ptr_val( - if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), - llvm_usize, - None, - ); - - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &ndarray, - |_, ctx, v| Ok(v.load_ndims(ctx)), - |generator, ctx, v, idx| unsafe { - Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None)) - }, - ) - .unwrap() - } - }); - - ndarray_broadcast_fill(generator, ctx, ndarray, lhs, rhs, |generator, ctx, elems| { - value_fn(generator, ctx, elems) - })?; - - Ok(ndarray) -} - /// Generates LLVM IR for `ndarray.empty`. pub fn gen_ndarray_empty<'ctx>( context: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 2bd02a71..d075bb89 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -16,7 +16,7 @@ use nac3parser::{ use parking_lot::RwLock; use super::{ - classes::{ListType, NDArrayType, ProxyType, RangeType}, + classes::{ListType, ProxyType, RangeType}, concrete_type::ConcreteTypeStore, CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask, CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry, @@ -462,15 +462,3 @@ fn test_classes_range_type_new() { let llvm_range = RangeType::new(&ctx); assert!(RangeType::is_type(llvm_range.as_base_type()).is_ok()); } - -#[test] -fn test_classes_ndarray_type_new() { - let ctx = inkwell::context::Context::create(); - let generator = DefaultCodeGenerator::new(String::new(), 64); - - let llvm_i32 = ctx.i32_type(); - let llvm_usize = generator.get_size_type(&ctx); - - let llvm_ndarray = NDArrayType::new(&generator, &ctx, llvm_i32.into()); - assert!(NDArrayType::is_type(llvm_ndarray.as_base_type(), llvm_usize).is_ok()); -} -- 2.44.2