From cc538d221ac61833a97ce18346a274702fe8631c Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 19 Feb 2024 17:10:18 +0800 Subject: [PATCH] core: Implement codegen for indexing into ndarray --- nac3core/src/codegen/expr.rs | 227 ++++++++++++++++++++++++++++++++++- 1 file changed, 224 insertions(+), 3 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 58f090f59..58a6d8d2d 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2,7 +2,7 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use crate::{ codegen::{ - classes::{ListValue, RangeValue}, + classes::{ListValue, NDArrayValue, RangeValue}, concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore}, gen_in_range_check, get_llvm_type, @@ -1190,6 +1190,213 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( } } +/// Generates code for a subscript expression on an `ndarray`. +/// +/// * `ty` - The `Type` of the `NDArray` elements. +/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`. +/// * `v` - The `NDArray` value. +/// * `slice` - The slice expression used to subscript into the `ndarray`. +fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ndims: Type, + v: NDArrayValue<'ctx>, + slice: &Expr>, +) -> Result>, String> { + let llvm_void = ctx.ctx.void_type(); + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); + + let TypeEnum::TLiteral { values, .. } = &*ctx.unifier.get_ty_immutable(ndims) else { + unreachable!() + }; + + let ndims = values.iter() + .map(|ndim| match *ndim { + SymbolValue::U64(v) => Ok(v), + SymbolValue::U32(v) => Ok(v as u64), + SymbolValue::I32(v) => u64::try_from(v) + .map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")), + SymbolValue::I64(v) => u64::try_from(v) + .map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")), + _ => unreachable!(), + }) + .collect::, _>>()?; + + assert!(!ndims.is_empty()); + + let ndarray_ty_enum = TypeEnum::TNDArray { + ty, + ndims: ctx.unifier.get_fresh_literal( + ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(), + None, + ), + }; + let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); + let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); + let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); + let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); + + // Check that len is non-zero + let len = v.load_ndims(ctx); + ctx.make_assert( + generator, + ctx.builder.build_int_compare(IntPredicate::SGT, len, llvm_usize.const_zero(), ""), + "0:IndexError", + "too many indices for array: array is {0}-dimensional but 1 were indexed", + [Some(len), None, None], + slice.location, + ); + + if ndims.len() == 1 && ndims[0] == 1 { + // Accessing an element from a 1-dimensional `ndarray` + + if let ExprKind::Slice { .. } = &slice.node { + return Err(String::from("subscript operator for ndarray not implemented")) + } + + let index = if let Some(v) = generator.gen_expr(ctx, slice)? { + v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() + } else { + return Ok(None) + }; + + Ok(Some(v.get_data() + .get_const( + ctx, + generator, + ctx.ctx.i32_type().const_array(&[index]), + None, + ) + .into())) + } else { + // Accessing an element from a multi-dimensional `ndarray` + + if let ExprKind::Slice { .. } = &slice.node { + return Err(String::from("subscript operator for ndarray not implemented")) + } + + let index = if let Some(v) = generator.gen_expr(ctx, slice)? { + v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() + } else { + return Ok(None) + }; + + // Create a new array, remove the top dimension from the dimension-size-list, and copy the + // elements over + let subscripted_ndarray = generator.gen_var_alloc( + ctx, + llvm_ndarray_t.into(), + None + )?; + let ndarray = NDArrayValue::from_ptr_val( + subscripted_ndarray, + llvm_usize, + None + ); + + let num_dims = v.load_ndims(ctx); + ndarray.store_ndims( + ctx, + generator, + ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), ""), + ); + + let ndarray_num_dims = ndarray.load_ndims(ctx); + ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims); + + let memcpy_fn_name = format!( + "llvm.memcpy.p0i8.p0i8.i{}", + generator.get_size_type(ctx.ctx).get_bit_width(), + ); + let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_void.fn_type( + &[ + llvm_pi8.into(), + llvm_pi8.into(), + llvm_usize.into(), + llvm_i1.into(), + ], + false, + ); + + ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None) + }); + + let ndarray_num_dims = ndarray.load_ndims(ctx); + let v_dims_src_ptr = v.get_dims().ptr_offset( + ctx, + generator, + llvm_usize.const_int(1, false), + None, + ); + ctx.builder.build_call( + memcpy_fn, + &[ + ctx.builder.build_bitcast( + ndarray.get_dims().get_ptr(ctx), + llvm_pi8, + "", + ).into(), + ctx.builder.build_bitcast( + v_dims_src_ptr, + llvm_pi8, + "", + ).into(), + ctx.builder.build_int_mul( + ndarray_num_dims.into(), + llvm_usize.size_of(), + "", + ).into(), + llvm_i1.const_zero().into(), + ], + "", + ); + + let ndarray_num_elems = call_ndarray_calc_size( + generator, + ctx, + ndarray.load_ndims(ctx), + ndarray.get_dims().get_ptr(ctx), + ); + ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); + + let v_data_src_ptr = v.get_data().ptr_offset_const( + ctx, + generator, + ctx.ctx.i32_type().const_array(&[index]), + None + ); + ctx.builder.build_call( + memcpy_fn, + &[ + ctx.builder.build_bitcast( + ndarray.get_data().get_ptr(ctx), + llvm_pi8, + "", + ).into(), + ctx.builder.build_bitcast( + v_data_src_ptr, + llvm_pi8, + "", + ).into(), + ctx.builder.build_int_mul( + ndarray_num_elems.into(), + llvm_ndarray_data_t.size_of().unwrap(), + "", + ).into(), + llvm_i1.const_zero().into(), + ], + "", + ); + + Ok(Some(v.get_ptr().into())) + } +} + /// See [`CodeGenerator::gen_expr`]. pub fn gen_expr<'ctx, G: CodeGenerator>( generator: &mut G, @@ -1810,8 +2017,22 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v.get_data().get(ctx, generator, index, None).into() } } - TypeEnum::TNDArray { .. } => { - return Err(String::from("subscript operator for ndarray not implemented")) + TypeEnum::TNDArray { ty, ndims } => { + let v = if let Some(v) = generator.gen_expr(ctx, value)? { + v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value() + } else { + return Ok(None) + }; + let v = NDArrayValue::from_ptr_val(v, usize, None); + + return gen_ndarray_subscript_expr( + generator, + ctx, + *ty, + *ndims, + v, + &*slice, + ) } TypeEnum::TTuple { .. } => { let index: u32 =