diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index fb395ea8..7381fac7 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -129,3 +129,74 @@ int32_t __nac3_range_slice_len(const int32_t start, const int32_t end, const int DEF_SLICE_ASSIGN(uint8_t) DEF_SLICE_ASSIGN(uint32_t) DEF_SLICE_ASSIGN(uint64_t) + +int32_t __nac3_list_slice_assign_var_size( + int32_t dest_start, + int32_t dest_end, + int32_t dest_step, + uint8_t *dest_arr, + int32_t dest_arr_len, + int32_t src_start, + int32_t src_end, + int32_t src_step, + uint8_t *src_arr, + int32_t src_arr_len, + const int32_t size +) { + /* if dest_arr_len == 0, do nothing since we do not support extending list */ + if (dest_arr_len == 0) return dest_arr_len; + /* if both step is 1, memmove directly, handle the dropping of the list, and shrink size */ + if (src_step == dest_step && dest_step == 1) { + const int32_t src_len = (src_end >= src_start) ? (src_end - src_start + 1) : 0; + const int32_t dest_len = (dest_end >= dest_start) ? (dest_end - dest_start + 1) : 0; + if (src_len > 0) { + __builtin_memmove( + dest_arr + dest_start * size, + src_arr + src_start * size, + src_len * size + ); + } + if (dest_len > 0) { + /* dropping */ + __builtin_memmove( + dest_arr + (dest_start + src_len) * size, + dest_arr + (dest_end + 1) * size, + (dest_arr_len - dest_end - 1) * size + ); + } + /* shrink size */ + return dest_arr_len - (dest_len - src_len); + } + /* if two range overlaps, need alloca */ + uint8_t need_alloca = + (dest_arr == src_arr) + && !( + MAX(dest_start, dest_end) < MIN(src_start, src_end) + || MAX(src_start, src_end) < MIN(dest_start, dest_end) + ); + if (need_alloca) { + uint8_t *tmp = alloca(src_arr_len * size); + __builtin_memcpy(tmp, src_arr, src_arr_len * size); + src_arr = tmp; + } + int32_t src_ind = src_start; + int32_t dest_ind = dest_start; + for (; + (src_step > 0) ? (src_ind <= src_end) : (src_ind >= src_end); + src_ind += src_step, dest_ind += dest_step + ) { + /* memcpy for var size, cannot overlap after previous alloca */ + __builtin_memcpy(dest_arr + dest_ind * size, src_arr + src_ind * size, size); + } + /* only dest_step == 1 can we shrink the dest list. */ + /* size should be ensured prior to calling this function */ + if (dest_step == 1 && dest_end >= dest_start) { + __builtin_memmove( + dest_arr + dest_ind * size, + dest_arr + (dest_end + 1) * size, + (dest_arr_len - dest_end - 1) * size + ); + return dest_arr_len - (dest_end - dest_ind) - 1; + } + return dest_arr_len; +} \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 6891c69b..88450f97 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -239,48 +239,52 @@ pub fn list_slice_assignment<'ctx, 'a>( let int32 = ctx.ctx.i32_type(); let int32_ptr = int32.ptr_type(AddressSpace::Generic); let int64_ptr = ctx.ctx.i64_type().ptr_type(AddressSpace::Generic); - let fun_symbol = if let BasicTypeEnum::IntType(ty) = ty { - match ty.get_bit_width() { - w if w < 32 => "__nac3_list_slice_assign_uint8_t", - 32 => "__nac3_list_slice_assign_uint32_t", - 64 => "__nac3_list_slice_assign_uint64_t", + let elem_size = match ty { + BasicTypeEnum::IntType(ty) => match ty.get_bit_width() { + w if w < 32 => Some(8), + w if w == 32 || w == 64 => Some(w), _ => unreachable!(), - } - } else if ty.is_float_type() { - "__nac3_list_slice_assign_uint64_t" - } else if ty.is_pointer_type() { - match size_ty.get_bit_width() { - 32 => "__nac3_list_slice_assign_uint32_t", - 64 => "__nac3_list_slice_assign_uint64_t", + }, + BasicTypeEnum::FloatType(_) => Some(64), + BasicTypeEnum::PointerType(_) => match size_ty.get_bit_width() { + w if w == 32 || w == 64 => Some(w), _ => unreachable!(), - } - } else { - unreachable!() - }; - let elem_ptr_type = match fun_symbol { - "__nac3_list_slice_assign_uint8_t" => int8_ptr, - "__nac3_list_slice_assign_uint32_t" => int32_ptr, - "__nac3_list_slice_assign_uint64_t" => int64_ptr, + }, + BasicTypeEnum::StructType(_) => None, _ => unreachable!(), }; - let slice_assign_fun = ctx.module.get_function(fun_symbol).unwrap_or_else(|| { - let fn_t = int32.fn_type( - &[ - int32.into(), // dest start idx - int32.into(), // dest end idx - int32.into(), // dest step - elem_ptr_type.into(), // dest arr ptr - int32.into(), // dest arr len - int32.into(), // src start idx - int32.into(), // src end idx - int32.into(), // src step - elem_ptr_type.into(), // src arr ptr - int32.into(), // src arr len - ], - false, - ); - ctx.module.add_function(fun_symbol, fn_t, None) - }); + let (fun_symbol, elem_ptr_type) = match elem_size { + Some(8) => ("__nac3_list_slice_assign_uint8_t", int8_ptr), + Some(32) => ("__nac3_list_slice_assign_uint32_t", int32_ptr), + Some(64) => ("__nac3_list_slice_assign_uint64_t", int64_ptr), + _ => ("__nac3_list_slice_assign_var_size", int8_ptr), + }; + let slice_assign_fun = { + let mut ty_vec = vec![ + int32.into(), // dest start idx + int32.into(), // dest end idx + int32.into(), // dest step + elem_ptr_type.into(), // dest arr ptr + int32.into(), // dest arr len + int32.into(), // src start idx + int32.into(), // src end idx + int32.into(), // src step + elem_ptr_type.into(), // src arr ptr + int32.into(), // src arr len + ]; + ctx.module.get_function(fun_symbol).unwrap_or_else(|| { + let fn_t = int32.fn_type( + { + if fun_symbol == "__nac3_list_slice_assign_var_size" { + ty_vec.push(int32.into()); + } + ty_vec.as_slice() + }, + false, + ); + ctx.module.add_function(fun_symbol, fn_t, None) + }) + }; let zero = int32.const_zero(); let one = int32.const_int(1, false); @@ -304,27 +308,36 @@ pub fn list_slice_assignment<'ctx, 'a>( // index in bound and positive should be done // TODO: assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and // throw exception if not satisfied - let new_len = ctx - .builder - .build_call( - slice_assign_fun, - &[ - dest_idx.0.into(), // dest start idx - dest_idx.1.into(), // dest end idx - dest_idx.2.into(), // dest step - dest_arr_ptr.into(), // dest arr ptr - dest_len.into(), // dest arr len - src_idx.0.into(), // src start idx - src_idx.1.into(), // src end idx - src_idx.2.into(), // src step - src_arr_ptr.into(), // src arr ptr - src_len.into(), // src arr len - ], - "slice_assign", - ) - .try_as_basic_value() - .unwrap_left() - .into_int_value(); + let new_len = { + let mut args = vec![ + dest_idx.0.into(), // dest start idx + dest_idx.1.into(), // dest end idx + dest_idx.2.into(), // dest step + dest_arr_ptr.into(), // dest arr ptr + dest_len.into(), // dest arr len + src_idx.0.into(), // src start idx + src_idx.1.into(), // src end idx + src_idx.2.into(), // src step + src_arr_ptr.into(), // src arr ptr + src_len.into(), // src arr len + ]; + ctx.builder + .build_call( + slice_assign_fun, + { + if fun_symbol == "__nac3_list_slice_assign_var_size" { + let s = ty.into_struct_type().size_of().unwrap(); + let s = ctx.builder.build_int_truncate_or_bit_cast(s, int32, "size"); + args.push(s.into()); + } + args.as_slice() + }, + "slice_assign", + ) + .try_as_basic_value() + .unwrap_left() + .into_int_value() + }; // update length let need_update = ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update");