forked from M-Labs/nac3
core: Add NDArrayValue and helper functions
This commit is contained in:
parent
148900302e
commit
8470915809
|
@ -3,7 +3,12 @@ use inkwell::{
|
|||
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
|
||||
values::{BasicValueEnum, IntValue, PointerValue},
|
||||
};
|
||||
use crate::codegen::{CodeGenContext, CodeGenerator};
|
||||
use crate::codegen::{
|
||||
CodeGenContext,
|
||||
CodeGenerator,
|
||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
||||
stmt::gen_for_callback,
|
||||
};
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
pub fn assert_is_list<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {}
|
||||
|
@ -380,3 +385,485 @@ impl<'ctx> RangeValue<'ctx> {
|
|||
ctx.builder.build_load(pstep, var_name.as_str()).into_int_value()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(debug_assertions))]
|
||||
pub fn assert_is_ndarray<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
pub fn assert_is_ndarray<'ctx>(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) {
|
||||
if let Err(msg) = NDArrayValue::is_instance(value, llvm_usize) {
|
||||
panic!("{msg}")
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy type for accessing an `NDArray` value in LLVM.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct NDArrayValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>);
|
||||
|
||||
impl<'ctx> NDArrayValue<'ctx> {
|
||||
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an
|
||||
/// instance.
|
||||
pub fn is_instance(
|
||||
value: PointerValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
let llvm_ndarray_ty = value.get_type().get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"))
|
||||
};
|
||||
if llvm_ndarray_ty.count_fields() != 3 {
|
||||
return Err(format!("Expected 3 fields in `NDArray`, got {}", llvm_ndarray_ty.count_fields()))
|
||||
}
|
||||
|
||||
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap();
|
||||
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
|
||||
return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}"))
|
||||
};
|
||||
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
|
||||
return Err(format!("Expected {}-bit int type for `ndarray.0`, got {}-bit int",
|
||||
llvm_usize.get_bit_width(),
|
||||
ndarray_ndims_ty.get_bit_width()))
|
||||
}
|
||||
|
||||
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
|
||||
let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else {
|
||||
return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"))
|
||||
};
|
||||
let ndarray_dims = ndarray_pdims.get_element_type();
|
||||
let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else {
|
||||
return Err(format!("Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"))
|
||||
};
|
||||
if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() {
|
||||
return Err(format!("Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
|
||||
llvm_usize.get_bit_width(),
|
||||
ndarray_dims.get_bit_width()))
|
||||
}
|
||||
|
||||
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
|
||||
let Ok(_) = PointerType::try_from(ndarray_data_ty) else {
|
||||
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"))
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Creates an [NDArrayValue] from a [PointerValue].
|
||||
pub fn from_ptr_val(
|
||||
ptr: PointerValue<'ctx>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
name: Option<&'ctx str>,
|
||||
) -> Self {
|
||||
assert_is_ndarray(ptr, llvm_usize);
|
||||
NDArrayValue(ptr, name)
|
||||
}
|
||||
|
||||
/// Returns the underlying [PointerValue] pointing to the `NDArray` instance.
|
||||
pub fn get_ptr(&self) -> PointerValue<'ctx> {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||
fn get_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let var_name = self.1.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
|
||||
|
||||
unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
self.0,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
var_name.as_str(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores the number of dimensions `ndims` into this instance.
|
||||
pub fn store_ndims(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &dyn CodeGenerator,
|
||||
ndims: IntValue<'ctx>,
|
||||
) {
|
||||
debug_assert_eq!(ndims.get_type(), generator.get_size_type(ctx.ctx));
|
||||
|
||||
let pndims = self.get_ndims(ctx);
|
||||
ctx.builder.build_store(pndims, ndims);
|
||||
}
|
||||
|
||||
/// Returns the number of dimensions of this `NDArray` as a value.
|
||||
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
let pndims = self.get_ndims(ctx);
|
||||
ctx.builder.build_load(pndims, "").into_int_value()
|
||||
}
|
||||
|
||||
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
|
||||
/// on the field.
|
||||
fn get_dims_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let var_name = self.1.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
|
||||
|
||||
unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
self.get_ptr(),
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
var_name.as_str(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores the array of dimension sizes `dims` into this instance.
|
||||
fn store_dims(&self, ctx: &CodeGenContext<'ctx, '_>, dims: PointerValue<'ctx>) {
|
||||
ctx.builder.build_store(self.get_dims_ptr(ctx), dims);
|
||||
}
|
||||
|
||||
/// Convenience method for creating a new array storing dimension sizes with the given `size`.
|
||||
pub fn create_dims(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
size: IntValue<'ctx>,
|
||||
) {
|
||||
self.store_dims(ctx, ctx.builder.build_array_alloca(llvm_usize, size, ""));
|
||||
}
|
||||
|
||||
/// Returns a proxy object to the field storing the size of each dimension of this `NDArray`.
|
||||
pub fn get_dims(&self) -> NDArrayDimsProxy<'ctx> {
|
||||
NDArrayDimsProxy(self.clone())
|
||||
}
|
||||
|
||||
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
|
||||
/// on the field.
|
||||
fn get_data_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let var_name = self.1.map(|v| format!("{v}.data.addr")).unwrap_or_default();
|
||||
|
||||
unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
self.get_ptr(),
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
|
||||
var_name.as_str(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores the array of data elements `data` into this instance.
|
||||
fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, data: PointerValue<'ctx>) {
|
||||
ctx.builder.build_store(self.get_data_ptr(ctx), data);
|
||||
}
|
||||
|
||||
/// Convenience method for creating a new array storing data elements with the given element
|
||||
/// type `elem_ty` and
|
||||
/// `size`.
|
||||
pub fn create_data(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
elem_ty: BasicTypeEnum<'ctx>,
|
||||
size: IntValue<'ctx>,
|
||||
) {
|
||||
self.store_data(ctx, ctx.builder.build_array_alloca(elem_ty, size, ""));
|
||||
}
|
||||
|
||||
/// Returns a proxy object to the field storing the data of this `NDArray`.
|
||||
pub fn get_data(&self) -> NDArrayDataProxy<'ctx> {
|
||||
NDArrayDataProxy(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> Into<PointerValue<'ctx>> for NDArrayValue<'ctx> {
|
||||
fn into(self) -> PointerValue<'ctx> {
|
||||
self.get_ptr()
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy type for accessing the `dims` array of an `NDArray` instance in LLVM.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct NDArrayDimsProxy<'ctx>(NDArrayValue<'ctx>);
|
||||
|
||||
impl<'ctx> NDArrayDimsProxy<'ctx> {
|
||||
/// Returns the single-indirection pointer to the array.
|
||||
pub fn get_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
let var_name = self.0.1.map(|v| format!("{v}.dims")).unwrap_or_default();
|
||||
|
||||
ctx.builder.build_load(self.0.get_dims_ptr(ctx), var_name.as_str()).into_pointer_value()
|
||||
}
|
||||
|
||||
/// Returns the pointer to the size of the `idx`-th dimension.
|
||||
pub fn ptr_offset(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
idx: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let in_range = ctx.builder.build_int_compare(
|
||||
IntPredicate::ULT,
|
||||
idx,
|
||||
self.0.load_ndims(ctx),
|
||||
""
|
||||
);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
in_range,
|
||||
"0:IndexError",
|
||||
"index {0} is out of bounds for axis 0 with size {1}",
|
||||
[Some(idx), Some(self.0.load_ndims(ctx)), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let var_name = name
|
||||
.map(|v| format!("{v}.addr"))
|
||||
.unwrap_or_default();
|
||||
|
||||
unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
self.get_ptr(ctx),
|
||||
&[idx],
|
||||
var_name.as_str(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the size of the `idx`-th dimension.
|
||||
pub fn get(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
idx: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> IntValue<'ctx> {
|
||||
let ptr = self.ptr_offset(ctx, generator, idx, name);
|
||||
ctx.builder.build_load(ptr, name.unwrap_or_default()).into_int_value()
|
||||
}
|
||||
}
|
||||
|
||||
/// Proxy type for accessing the `data` array of an `NDArray` instance in LLVM.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct NDArrayDataProxy<'ctx>(NDArrayValue<'ctx>);
|
||||
|
||||
impl<'ctx> NDArrayDataProxy<'ctx> {
|
||||
/// Returns the single-indirection pointer to the array.
|
||||
pub fn get_ptr(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
let var_name = self.0.1.map(|v| format!("{v}.data")).unwrap_or_default();
|
||||
|
||||
ctx.builder.build_load(self.0.get_data_ptr(ctx), var_name.as_str()).into_pointer_value()
|
||||
}
|
||||
|
||||
pub unsafe fn ptr_to_data_flattened_unchecked(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
idx: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
self.get_ptr(ctx),
|
||||
&[idx],
|
||||
name.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the pointer to the data at the `idx`-th flattened index.
|
||||
pub fn ptr_to_data_flattened(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
idx: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let ndims = self.0.load_ndims(ctx);
|
||||
let dims = self.0.get_dims().get_ptr(ctx);
|
||||
let data_sz = call_ndarray_calc_size(generator, ctx, ndims, dims);
|
||||
|
||||
let in_range = ctx.builder.build_int_compare(
|
||||
IntPredicate::ULT,
|
||||
idx,
|
||||
data_sz,
|
||||
""
|
||||
);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
in_range,
|
||||
"0:IndexError",
|
||||
"index {0} is out of bounds with size {1}",
|
||||
[Some(idx), Some(self.0.load_ndims(ctx)), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
unsafe {
|
||||
self.ptr_to_data_flattened_unchecked(ctx, idx, name)
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn get_flattened_unchecked(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
idx: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let ptr = self.ptr_to_data_flattened_unchecked(ctx, idx, name);
|
||||
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Returns the data at the `idx`-th flattened index.
|
||||
pub fn get_flattened(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
idx: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let ptr = self.ptr_to_data_flattened(ctx, generator, idx, name);
|
||||
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||
}
|
||||
|
||||
pub unsafe fn ptr_offset_unchecked(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
generator: &dyn CodeGenerator,
|
||||
indices: ListValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let indices_elem_ty = indices.get_data().get_ptr(ctx).get_type().get_element_type();
|
||||
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
|
||||
panic!("Expected list[int32] but got {indices_elem_ty}")
|
||||
};
|
||||
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got {indices_elem_ty}");
|
||||
|
||||
let index = call_ndarray_flatten_index(
|
||||
generator,
|
||||
ctx,
|
||||
self.0,
|
||||
indices,
|
||||
).unwrap();
|
||||
|
||||
unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
self.get_ptr(ctx),
|
||||
&[index],
|
||||
name.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the pointer to the data at the index specified by `indices`.
|
||||
pub fn ptr_offset(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
indices: ListValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let nidx_leq_ndims = ctx.builder.build_int_compare(
|
||||
IntPredicate::SLE,
|
||||
indices.load_size(ctx, None),
|
||||
self.0.load_ndims(ctx),
|
||||
""
|
||||
);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
nidx_leq_ndims,
|
||||
"0:IndexError",
|
||||
"invalid index to scalar variable",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(i, llvm_usize.const_zero());
|
||||
|
||||
Ok(i)
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let indices_len = indices.load_size(ctx, None);
|
||||
let ndarray_len = self.0.load_ndims(ctx);
|
||||
|
||||
let min_fn_name = format!("llvm.umin.i{}", llvm_usize.get_bit_width());
|
||||
let min_fn = ctx.module.get_function(min_fn_name.as_str()).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[llvm_usize.into(), llvm_usize.into()],
|
||||
false
|
||||
);
|
||||
ctx.module.add_function(min_fn_name.as_str(), fn_type, None)
|
||||
});
|
||||
|
||||
let len = ctx
|
||||
.builder
|
||||
.build_call(min_fn, &[indices_len.into(), ndarray_len.into()], "")
|
||||
.try_as_basic_value()
|
||||
.map_left(|v| v.into_int_value())
|
||||
.left()
|
||||
.unwrap();
|
||||
|
||||
let i = ctx.builder.build_load(i_addr, "").into_int_value();
|
||||
Ok(ctx.builder.build_int_compare(IntPredicate::SLT, i, len, ""))
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let i = ctx.builder.build_load(i_addr, "").into_int_value();
|
||||
let (dim_idx, dim_sz) = unsafe {
|
||||
(
|
||||
indices.get_data().get_unchecked(ctx, i, None).into_int_value(),
|
||||
self.0.get_dims().get(ctx, generator, i, None),
|
||||
)
|
||||
};
|
||||
|
||||
let dim_lt = ctx.builder.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
dim_idx,
|
||||
dim_sz,
|
||||
""
|
||||
);
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
dim_lt,
|
||||
"0:IndexError",
|
||||
"index {0} is out of bounds for axis 0 with size {1}",
|
||||
[Some(dim_idx), Some(dim_sz), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.into_int_value();
|
||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "");
|
||||
ctx.builder.build_store(i_addr, i);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
).unwrap();
|
||||
|
||||
unsafe {
|
||||
self.ptr_offset_unchecked(ctx, generator, indices, name)
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn get_unsafe(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &dyn CodeGenerator,
|
||||
indices: ListValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let ptr = self.ptr_offset_unchecked(ctx, generator, indices, name);
|
||||
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Returns the data at the index specified by `indices`.
|
||||
pub fn get(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
indices: ListValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
let ptr = self.ptr_offset(ctx, generator, indices, name);
|
||||
ctx.builder.build_load(ptr, name.unwrap_or_default())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -269,3 +269,39 @@ void __nac3_ndarray_calc_nd_indices64(
|
|||
stride *= dims[i];
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t __nac3_ndarray_flatten_index(
|
||||
const uint32_t* dims,
|
||||
uint32_t num_dims,
|
||||
const uint32_t* indices,
|
||||
uint32_t num_indices
|
||||
) {
|
||||
uint32_t idx = 0;
|
||||
uint32_t stride = 1;
|
||||
for (uint32_t i = num_dims - 1; i-- >= 0; ) {
|
||||
if (i < num_indices) {
|
||||
idx += (stride * indices[i]);
|
||||
}
|
||||
|
||||
stride *= dims[i];
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
uint64_t __nac3_ndarray_flatten_index64(
|
||||
const uint64_t* dims,
|
||||
uint64_t num_dims,
|
||||
const uint32_t* indices,
|
||||
uint64_t num_indices
|
||||
) {
|
||||
uint64_t idx = 0;
|
||||
uint64_t stride = 1;
|
||||
for (uint64_t i = num_dims - 1; i-- >= 0; ) {
|
||||
if (i < num_indices) {
|
||||
idx += (stride * indices[i]);
|
||||
}
|
||||
|
||||
stride *= dims[i];
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
use crate::typecheck::typedef::Type;
|
||||
|
||||
use super::{
|
||||
classes::ListValue,
|
||||
assert_is_ndarray,
|
||||
classes::{ListValue, NDArrayValue},
|
||||
CodeGenContext,
|
||||
CodeGenerator,
|
||||
};
|
||||
|
@ -607,7 +606,7 @@ pub fn call_ndarray_calc_size<'ctx>(
|
|||
pub fn call_ndarray_init_dims<'ctx>(
|
||||
generator: &dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: PointerValue<'ctx>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
shape: ListValue<'ctx>,
|
||||
) {
|
||||
let llvm_void = ctx.ctx.void_type();
|
||||
|
@ -617,8 +616,6 @@ pub fn call_ndarray_init_dims<'ctx>(
|
|||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
assert_is_ndarray(ndarray);
|
||||
|
||||
let ndarray_init_dims_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_init_dims",
|
||||
64 => "__nac3_ndarray_init_dims64",
|
||||
|
@ -637,22 +634,14 @@ pub fn call_ndarray_init_dims<'ctx>(
|
|||
ctx.module.add_function(ndarray_init_dims_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_dims = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
None,
|
||||
);
|
||||
let ndarray_dims = ndarray.get_dims();
|
||||
let shape_data = shape.get_data();
|
||||
let ndarray_num_dims = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_int_value();
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
|
||||
ctx.builder.build_call(
|
||||
ndarray_init_dims_fn,
|
||||
&[
|
||||
ndarray_dims.into(),
|
||||
ndarray_dims.get_ptr(ctx).into(),
|
||||
shape_data.get_ptr(ctx).into(),
|
||||
ndarray_num_dims.into(),
|
||||
],
|
||||
|
@ -669,12 +658,9 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
|
|||
generator: &dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
index: IntValue<'ctx>,
|
||||
ndarray: PointerValue<'ctx>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert_is_ndarray(ndarray);
|
||||
|
||||
let llvm_void = ctx.ctx.void_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
@ -698,16 +684,8 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
|
|||
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_int_value();
|
||||
let ndarray_dims = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
None,
|
||||
).into_pointer_value();
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let ndarray_dims = ndarray.get_dims();
|
||||
|
||||
let indices = ctx.builder.build_array_alloca(
|
||||
llvm_usize,
|
||||
|
@ -719,7 +697,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
|
|||
ndarray_calc_nd_indices_fn,
|
||||
&[
|
||||
index.into(),
|
||||
ndarray_dims.into(),
|
||||
ndarray_dims.get_ptr(ctx).into(),
|
||||
ndarray_num_dims.into(),
|
||||
indices.into(),
|
||||
],
|
||||
|
@ -728,3 +706,63 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
|
|||
|
||||
Ok(indices)
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_flatten_index`.
|
||||
///
|
||||
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
||||
pub fn call_ndarray_flatten_index<'ctx>(
|
||||
generator: &dyn CodeGenerator,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
indices: ListValue<'ctx>,
|
||||
) -> Result<IntValue<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_flatten_index",
|
||||
64 => "__nac3_ndarray_flatten_index64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
||||
};
|
||||
let ndarray_flatten_index_fn = ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
llvm_pi32.into(),
|
||||
llvm_pusize.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
let ndarray_dims = ndarray.get_dims();
|
||||
let indices_size = indices.load_size(ctx, None);
|
||||
let indices_data = indices.get_data();
|
||||
|
||||
let index = ctx.builder
|
||||
.build_call(
|
||||
ndarray_flatten_index_fn,
|
||||
&[
|
||||
ndarray_num_dims.into(),
|
||||
ndarray_dims.get_ptr(ctx).into(),
|
||||
indices_size.into(),
|
||||
indices_data.get_ptr(ctx).into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.try_as_basic_value()
|
||||
.map_left(|v| v.into_int_value())
|
||||
.left()
|
||||
.unwrap();
|
||||
|
||||
Ok(index)
|
||||
}
|
|
@ -34,9 +34,6 @@ use std::sync::{
|
|||
};
|
||||
use std::thread;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
use inkwell::types::AnyTypeEnum;
|
||||
|
||||
pub mod classes;
|
||||
pub mod concrete_type;
|
||||
pub mod expr;
|
||||
|
@ -999,27 +996,3 @@ fn gen_in_range_check<'ctx>(
|
|||
|
||||
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")
|
||||
}
|
||||
|
||||
/// Checks whether the pointer `value` refers to an `NDArray` in LLVM.
|
||||
fn assert_is_ndarray(value: PointerValue) -> PointerValue {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let llvm_ndarray_ty = value.get_type().get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||
panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}")
|
||||
};
|
||||
|
||||
assert_eq!(llvm_ndarray_ty.count_fields(), 3);
|
||||
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..))));
|
||||
let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else {
|
||||
unreachable!()
|
||||
};
|
||||
let BasicTypeEnum::PointerType(dims) = ndarray_dims else {
|
||||
panic!("Expected pointer type for `list.1`, but got {ndarray_dims}")
|
||||
};
|
||||
assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..)));
|
||||
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..))));
|
||||
}
|
||||
|
||||
value
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ use inkwell::values::{ArrayValue, IntValue};
|
|||
use nac3parser::ast::StrRef;
|
||||
use crate::{
|
||||
codegen::{
|
||||
classes::ListValue,
|
||||
classes::{ListValue, NDArrayValue},
|
||||
CodeGenContext,
|
||||
CodeGenerator,
|
||||
irrt::{
|
||||
|
@ -27,11 +27,10 @@ fn create_ndarray_const_shape<'ctx, 'a>(
|
|||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
shape: ArrayValue<'ctx>
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
|
||||
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
|
||||
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||
|
@ -68,54 +67,18 @@ fn create_ndarray_const_shape<'ctx, 'a>(
|
|||
llvm_ndarray_t.into(),
|
||||
None,
|
||||
)?;
|
||||
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
|
||||
|
||||
let num_dims = llvm_usize.const_int(shape.get_type().len() as u64, false);
|
||||
ndarray.store_ndims(ctx, generator, num_dims);
|
||||
|
||||
let ndarray_num_dims = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
"",
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(ndarray_num_dims, num_dims);
|
||||
|
||||
let ndarray_dims = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
"",
|
||||
)
|
||||
};
|
||||
|
||||
let ndarray_num_dims = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_int_value();
|
||||
|
||||
ctx.builder.build_store(
|
||||
ndarray_dims,
|
||||
ctx.builder.build_array_alloca(
|
||||
llvm_usize,
|
||||
ndarray_num_dims,
|
||||
"",
|
||||
),
|
||||
);
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
|
||||
|
||||
for i in 0..shape.get_type().len() {
|
||||
let ndarray_dim = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
None,
|
||||
).into_pointer_value();
|
||||
let ndarray_dim = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray_dim,
|
||||
&[llvm_i32.const_int(i as u64, true)],
|
||||
"",
|
||||
)
|
||||
};
|
||||
let ndarray_dim = ndarray
|
||||
.get_dims()
|
||||
.ptr_offset(ctx, generator, llvm_usize.const_int(i as u64, true), None);
|
||||
let shape_dim = ctx.builder.build_extract_value(shape, i, "")
|
||||
.map(|val| val.into_int_value())
|
||||
.unwrap();
|
||||
|
@ -123,42 +86,14 @@ fn create_ndarray_const_shape<'ctx, 'a>(
|
|||
ctx.builder.build_store(ndarray_dim, shape_dim);
|
||||
}
|
||||
|
||||
let (ndarray_num_dims, ndarray_dims) = unsafe {
|
||||
(
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
""
|
||||
),
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
""
|
||||
),
|
||||
)
|
||||
};
|
||||
let ndarray_dims = ndarray.get_dims().get_ptr(ctx);
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.builder.build_load(ndarray_num_dims, "").into_int_value(),
|
||||
ctx.builder.build_load(ndarray_dims, "").into_pointer_value(),
|
||||
);
|
||||
|
||||
let ndarray_data = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
|
||||
"",
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(
|
||||
ndarray_data,
|
||||
ctx.builder.build_array_alloca(
|
||||
llvm_ndarray_data_t,
|
||||
ndarray_num_elems,
|
||||
""
|
||||
),
|
||||
ndarray.load_ndims(ctx),
|
||||
ndarray_dims,
|
||||
);
|
||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
@ -214,7 +149,7 @@ fn call_ndarray_empty_impl<'ctx, 'a>(
|
|||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
shape: ListValue<'ctx>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
|
||||
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
|
||||
|
||||
|
@ -284,79 +219,23 @@ fn call_ndarray_empty_impl<'ctx, 'a>(
|
|||
llvm_ndarray_t.into(),
|
||||
None,
|
||||
)?;
|
||||
let ndarray = NDArrayValue::from_ptr_val(ndarray, llvm_usize, None);
|
||||
|
||||
let num_dims = shape.load_size(ctx, None);
|
||||
ndarray.store_ndims(ctx, generator, num_dims);
|
||||
|
||||
let ndarray_num_dims = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
"",
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(ndarray_num_dims, num_dims);
|
||||
|
||||
let ndarray_dims = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
"",
|
||||
)
|
||||
};
|
||||
|
||||
let ndarray_num_dims = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_int_value();
|
||||
|
||||
ctx.builder.build_store(
|
||||
ndarray_dims,
|
||||
ctx.builder.build_array_alloca(
|
||||
llvm_usize,
|
||||
ndarray_num_dims,
|
||||
"",
|
||||
),
|
||||
);
|
||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
||||
ndarray.create_dims(ctx, llvm_usize, ndarray_num_dims);
|
||||
|
||||
call_ndarray_init_dims(generator, ctx, ndarray, shape);
|
||||
|
||||
let (ndarray_num_dims, ndarray_dims) = unsafe {
|
||||
(
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
""
|
||||
),
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
""
|
||||
),
|
||||
)
|
||||
};
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.builder.build_load(ndarray_num_dims, "").into_int_value(),
|
||||
ctx.builder.build_load(ndarray_dims, "").into_pointer_value(),
|
||||
);
|
||||
|
||||
let ndarray_data = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
|
||||
"",
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(
|
||||
ndarray_data,
|
||||
ctx.builder.build_array_alloca(
|
||||
llvm_ndarray_data_t,
|
||||
ndarray_num_elems,
|
||||
"",
|
||||
),
|
||||
ndarray.load_ndims(ctx),
|
||||
ndarray.get_dims().get_ptr(ctx),
|
||||
);
|
||||
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
@ -369,35 +248,19 @@ fn call_ndarray_empty_impl<'ctx, 'a>(
|
|||
fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ndarray: PointerValue<'ctx>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
value_fn: ValueFn,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (num_dims, dims) = unsafe {
|
||||
(
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
""
|
||||
),
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
""
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.builder.build_load(num_dims, "").into_int_value(),
|
||||
ctx.builder.build_load(dims, "").into_pointer_value(),
|
||||
ndarray.load_ndims(ctx),
|
||||
ndarray.get_dims().get_ptr(ctx),
|
||||
);
|
||||
|
||||
gen_for_callback(
|
||||
|
@ -417,21 +280,11 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
|||
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, ""))
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let ndarray_data = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
|
||||
None
|
||||
).into_pointer_value();
|
||||
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.into_int_value();
|
||||
let elem = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray_data,
|
||||
&[i],
|
||||
""
|
||||
)
|
||||
ndarray.get_data().ptr_to_data_flattened_unchecked(ctx, i, None)
|
||||
};
|
||||
|
||||
let value = value_fn(generator, ctx, i)?;
|
||||
|
@ -459,7 +312,7 @@ fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
|||
fn ndarray_fill_indexed<'ctx, 'a, ValueFn>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ndarray: PointerValue<'ctx>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
value_fn: ValueFn,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
|
@ -491,7 +344,7 @@ fn call_ndarray_zeros_impl<'ctx, 'a>(
|
|||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
shape: ListValue<'ctx>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let supported_types = [
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.int64,
|
||||
|
@ -527,7 +380,7 @@ fn call_ndarray_ones_impl<'ctx, 'a>(
|
|||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
shape: ListValue<'ctx>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let supported_types = [
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.int64,
|
||||
|
@ -564,7 +417,7 @@ fn call_ndarray_full_impl<'ctx, 'a>(
|
|||
elem_ty: Type,
|
||||
shape: ListValue<'ctx>,
|
||||
fill_value: BasicValueEnum<'ctx>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
||||
ndarray_fill_flattened(
|
||||
generator,
|
||||
|
@ -633,7 +486,7 @@ fn call_ndarray_eye_impl<'ctx, 'a>(
|
|||
nrows: IntValue<'ctx>,
|
||||
ncols: IntValue<'ctx>,
|
||||
offset: IntValue<'ctx>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
) -> Result<NDArrayValue<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize_2 = llvm_usize.array_type(2);
|
||||
|
@ -718,7 +571,7 @@ pub fn gen_ndarray_empty<'ctx, 'a>(
|
|||
context,
|
||||
context.primitives.float,
|
||||
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
|
||||
)
|
||||
).map(NDArrayValue::into)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.zeros`.
|
||||
|
@ -742,7 +595,7 @@ pub fn gen_ndarray_zeros<'ctx, 'a>(
|
|||
context,
|
||||
context.primitives.float,
|
||||
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
|
||||
)
|
||||
).map(NDArrayValue::into)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.ones`.
|
||||
|
@ -766,7 +619,7 @@ pub fn gen_ndarray_ones<'ctx, 'a>(
|
|||
context,
|
||||
context.primitives.float,
|
||||
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
|
||||
)
|
||||
).map(NDArrayValue::into)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.full`.
|
||||
|
@ -794,7 +647,7 @@ pub fn gen_ndarray_full<'ctx, 'a>(
|
|||
fill_value_ty,
|
||||
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
|
||||
fill_value_arg,
|
||||
)
|
||||
).map(NDArrayValue::into)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.eye`.
|
||||
|
@ -839,7 +692,7 @@ pub fn gen_ndarray_eye<'ctx, 'a>(
|
|||
nrows_arg.into_int_value(),
|
||||
ncols_arg.into_int_value(),
|
||||
offset_arg.into_int_value(),
|
||||
)
|
||||
).map(NDArrayValue::into)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.identity`.
|
||||
|
@ -866,5 +719,5 @@ pub fn gen_ndarray_identity<'ctx, 'a>(
|
|||
n_arg.into_int_value(),
|
||||
n_arg.into_int_value(),
|
||||
llvm_usize.const_zero(),
|
||||
)
|
||||
).map(NDArrayValue::into)
|
||||
}
|
Loading…
Reference in New Issue