diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index 778ad292e..e9872b0c7 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -224,24 +224,6 @@ uint64_t __nac3_ndarray_calc_size64( return num_elems; } -void __nac3_ndarray_init_dims( - uint32_t *ndarray_dims, - const int32_t *shape_data, - uint32_t shape_len -) { - __builtin_memcpy(ndarray_dims, shape_data, shape_len * sizeof(int32_t)); -} - -void __nac3_ndarray_init_dims64( - uint64_t *ndarray_dims, - const int32_t *shape_data, - uint64_t shape_len -) { - for (uint64_t i = 0; i < shape_len; ++i) { - ndarray_dims[i] = (uint64_t) shape_data[i]; - } -} - void __nac3_ndarray_calc_nd_indices( uint32_t index, const uint32_t* dims, diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index fffbc426f..8caccf0ec 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -617,60 +617,6 @@ pub fn call_ndarray_calc_size<'ctx>( .unwrap() } -/// Generates a call to `__nac3_ndarray_init_dims`. -/// -/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an -/// `NDArray`. -/// * `shape` - LLVM pointer to the `shape` of the `NDArray`. This value must be the LLVM -/// representation of a `list`. -pub fn call_ndarray_init_dims<'ctx>( - generator: &dyn CodeGenerator, - ctx: &mut CodeGenContext<'ctx, '_>, - ndarray: NDArrayValue<'ctx>, - shape: ListValue<'ctx>, -) { - let llvm_void = ctx.ctx.void_type(); - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); - let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); - - let ndarray_init_dims_fn_name = match llvm_usize.get_bit_width() { - 32 => "__nac3_ndarray_init_dims", - 64 => "__nac3_ndarray_init_dims64", - bw => unreachable!("Unsupported size type bit width: {}", bw) - }; - let ndarray_init_dims_fn = ctx.module.get_function(ndarray_init_dims_fn_name).unwrap_or_else(|| { - let fn_type = llvm_void.fn_type( - &[ - llvm_pusize.into(), - llvm_pi32.into(), - llvm_usize.into(), - ], - false, - ); - - ctx.module.add_function(ndarray_init_dims_fn_name, fn_type, None) - }); - - let ndarray_dims = ndarray.get_dims(); - let shape_data = shape.get_data(); - let ndarray_num_dims = ndarray.load_ndims(ctx); - - ctx.builder - .build_call( - ndarray_init_dims_fn, - &[ - ndarray_dims.get_ptr(ctx).into(), - shape_data.get_ptr(ctx).into(), - ndarray_num_dims.into(), - ], - "", - ) - .unwrap(); -} - /// Generates a call to `__nac3_ndarray_calc_nd_indices`. /// /// * `index` - The index to compute the multidimensional index for. diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index bc992eff8..26c8044db 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -10,7 +10,6 @@ use crate::{ irrt::{ call_ndarray_calc_nd_indices, call_ndarray_calc_size, - call_ndarray_init_dims, }, llvm_intrinsics::call_memcpy_generic, stmt::gen_for_callback @@ -78,6 +77,161 @@ pub fn unpack_ndarray_tvars( .unwrap() } +/// Creates an `NDArray` instance from a dynamic shape. +/// +/// * `elem_ty` - The element type of the `NDArray`. +/// * `shape` - The shape of the `NDArray`. +/// * `shape_len_fn` - A function that retrieves the number of dimensions from `shape`. +/// * `shape_data_fn` - A function that retrieves the size of a dimension from `shape`. +fn create_ndarray_dyn_shape<'ctx, 'a, V, LenFn, DataFn>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + shape: &V, + shape_len_fn: LenFn, + shape_data_fn: DataFn, +) -> Result, String> + where + LenFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, &V) -> Result, String>, + DataFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, &V, IntValue<'ctx>) -> Result, String>, +{ + let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); + + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); + let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); + let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum(); + assert!(llvm_ndarray_data_t.is_sized()); + + // Assert that all dimensions are non-negative + gen_for_callback( + generator, + ctx, + |generator, ctx| { + let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; + ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap(); + + Ok(i) + }, + |generator, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let shape_len = shape_len_fn(generator, ctx, shape)?; + debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width()); + + Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap()) + }, + |generator, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let shape_dim = shape_data_fn(generator, ctx, shape, i)?; + debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); + + let shape_dim_gez = ctx.builder + .build_int_compare(IntPredicate::SGE, shape_dim, shape_dim.get_type().const_zero(), "") + .unwrap(); + + ctx.make_assert( + generator, + shape_dim_gez, + "0:ValueError", + "negative dimensions not supported", + [None, None, None], + ctx.current_loc, + ); + + Ok(()) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap(); + ctx.builder.build_store(i_addr, i).unwrap(); + + Ok(()) + }, + )?; + + let ndarray = generator.gen_var_alloc( + ctx, + llvm_ndarray_t.into(), + None, + )?; + let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None); + + let num_dims = shape_len_fn(generator, ctx, shape)?; + ndarray.store_ndims(ctx, generator, num_dims); + + let ndarray_num_dims = ndarray.load_ndims(ctx); + ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims); + + // Copy the dimension sizes from shape to ndarray.dims + gen_for_callback( + generator, + ctx, + |generator, ctx| { + let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; + ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap(); + + Ok(i) + }, + |generator, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let shape_len = shape_len_fn(generator, ctx, shape)?; + debug_assert!(shape_len.get_type().get_bit_width() <= llvm_usize.get_bit_width()); + + Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap()) + }, + |generator, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let shape_dim = shape_data_fn(generator, ctx, shape, i)?; + debug_assert!(shape_dim.get_type().get_bit_width() <= llvm_usize.get_bit_width()); + let shape_dim = ctx.builder + .build_int_z_extend(shape_dim, llvm_usize, "") + .unwrap(); + + let ndarray_pdim = ndarray.get_dims().ptr_offset(ctx, generator, i, None); + + ctx.builder.build_store(ndarray_pdim, shape_dim).unwrap(); + + Ok(()) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .map(BasicValueEnum::into_int_value) + .unwrap(); + let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap(); + ctx.builder.build_store(i_addr, i).unwrap(); + + Ok(()) + }, + )?; + + let ndarray_num_elems = call_ndarray_calc_size( + generator, + ctx, + ndarray.load_ndims(ctx), + ndarray.get_dims().get_ptr(ctx), + ); + ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); + + Ok(ndarray) +} + /// Creates an `NDArray` instance from a constant shape. /// /// * `elem_ty` - The element type of the `NDArray`. @@ -205,98 +359,18 @@ fn call_ndarray_empty_impl<'ctx>( elem_ty: Type, shape: ListValue<'ctx>, ) -> Result, String> { - let ndarray_ty = make_ndarray_ty( - &mut ctx.unifier, - &ctx.primitives, - Some(elem_ty), - None, - ); - - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); - let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); - let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum(); - assert!(llvm_ndarray_data_t.is_sized()); - - // Assert that all dimensions are non-negative - gen_for_callback( + create_ndarray_dyn_shape( generator, ctx, - |generator, ctx| { - let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; - ctx.builder.build_store(i, llvm_usize.const_zero()).unwrap(); - - Ok(i) + elem_ty, + &shape, + |_, ctx, shape| { + Ok(shape.load_size(ctx, None)) }, - |_, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let shape_len = shape.load_size(ctx, None); - - Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, shape_len, "").unwrap()) + |generator, ctx, shape, idx| { + Ok(shape.get_data().get(ctx, generator, idx, None).into_int_value()) }, - |generator, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let shape_dim = shape.get_data().get(ctx, generator, i, None).into_int_value(); - - let shape_dim_gez = ctx.builder - .build_int_compare(IntPredicate::SGE, shape_dim, llvm_i32.const_zero(), "") - .unwrap(); - - ctx.make_assert( - generator, - shape_dim_gez, - "0:ValueError", - "negative dimensions not supported", - [None, None, None], - ctx.current_loc, - ); - - Ok(()) - }, - |_, ctx, i_addr| { - let i = ctx.builder - .build_load(i_addr, "") - .map(BasicValueEnum::into_int_value) - .unwrap(); - let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "").unwrap(); - ctx.builder.build_store(i_addr, i).unwrap(); - - Ok(()) - }, - )?; - - let ndarray = generator.gen_var_alloc( - ctx, - llvm_ndarray_t.into(), - None, - )?; - let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None); - - let num_dims = shape.load_size(ctx, None); - ndarray.store_ndims(ctx, generator, num_dims); - - let ndarray_num_dims = ndarray.load_ndims(ctx); - ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims); - - call_ndarray_init_dims(generator, ctx, ndarray, shape); - - let ndarray_num_elems = call_ndarray_calc_size( - generator, - ctx, - ndarray.load_ndims(ctx), - ndarray.get_dims().get_ptr(ctx), - ); - ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems); - - Ok(ndarray) + ) } /// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as