1
0
forked from M-Labs/nac3

[core] codegen: Implement ContiguousNDArray

Fixes compatibility with linalg algorithms. matrix_power is missing due
to the need for indexing support.
This commit is contained in:
David Mak 2024-11-28 12:45:17 +08:00
parent 57da0f67d1
commit e965a7c7ce
7 changed files with 835 additions and 308 deletions

View File

@ -14,6 +14,7 @@ use super::{
numpy,
numpy::ndarray_elementwise_unaryop_impl,
stmt::gen_for_callback_incrementing,
types::NDArrayType,
values::{
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
@ -1982,288 +1983,290 @@ fn build_output_struct<'ctx>(
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_cholesky";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
}
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
.construct_uninitialized(generator, ctx, llvm_usize.const_int(2, false), None);
out.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { out.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let out_c = out.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_cholesky(
ctx,
x1_c.as_base_value().into(),
out_c.as_base_value().into(),
None,
);
Ok(out.as_base_value().into())
}
/// Invokes the `np_linalg_qr` linalg function
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_qr";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only");
};
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
let out_ptr = build_output_struct(ctx, vec![out_q, out_r]);
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
}
let x1_shape = x1.shape();
let d0 =
unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let d1 = unsafe {
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
};
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None);
unsafe { q.create_data(generator, ctx) };
let r = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[dk, d1], None);
unsafe { r.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let q_c = q.make_contiguous_ndarray(generator, ctx);
let r_c = r.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_qr(
ctx,
x1_c.as_base_value().into(),
q_c.as_base_value().into(),
r_c.as_base_value().into(),
None,
);
let q = q.as_base_value().into();
let r = r.as_base_value().into();
let out_ptr = build_output_struct(ctx, vec![q, r]);
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
}
/// Invokes the `np_linalg_svd` linalg function
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_svd";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None);
let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]);
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
}
let x1_shape = x1.shape();
let d0 =
unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let d1 = unsafe {
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
};
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1));
let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None);
unsafe { u.create_data(generator, ctx) };
let s = out_ndarray1_ty.construct_dyn_shape(generator, ctx, &[dk], None);
unsafe { s.create_data(generator, ctx) };
let vh = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d1, d1], None);
unsafe { vh.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let u_c = u.make_contiguous_ndarray(generator, ctx);
let s_c = s.make_contiguous_ndarray(generator, ctx);
let vh_c = vh.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_svd(
ctx,
x1_c.as_base_value().into(),
u_c.as_base_value().into(),
s_c.as_base_value().into(),
vh_c.as_base_value().into(),
None,
);
let u = u.as_base_value().into();
let s = s.as_base_value().into();
let vh = vh.as_base_value().into();
let out_ptr = build_output_struct(ctx, vec![u, s, vh]);
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
}
/// Invokes the `np_linalg_inv` linalg function
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_inv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_inv(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
}
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
.construct_uninitialized(generator, ctx, llvm_usize.const_int(2, false), None);
out.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { out.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let out_c = out.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_inv(
ctx,
x1_c.as_base_value().into(),
out_c.as_base_value().into(),
None,
);
Ok(out.as_base_value().into())
}
/// Invokes the `np_linalg_pinv` linalg function
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_pinv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_np_linalg_pinv(ctx, x1, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
}
let x1_shape = x1.shape();
let d0 =
unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let d1 = unsafe {
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
};
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
.construct_dyn_shape(generator, ctx, &[d0, d1], None);
unsafe { out.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let out_c = out.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_pinv(
ctx,
x1_c.as_base_value().into(),
out_c.as_base_value().into(),
None,
);
Ok(out.as_base_value().into())
}
/// Invokes the `sp_linalg_lu` linalg function
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_lu";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let dim1 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
.into_int_value()
};
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None);
let out_ptr = build_output_struct(ctx, vec![out_l, out_u]);
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
}
let x1_shape = x1.shape();
let d0 =
unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
let d1 = unsafe {
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
};
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None);
unsafe { l.create_data(generator, ctx) };
let u = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[dk, d1], None);
unsafe { u.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let l_c = l.make_contiguous_ndarray(generator, ctx);
let u_c = u.make_contiguous_ndarray(generator, ctx);
extern_fns::call_sp_linalg_lu(
ctx,
x1_c.as_base_value().into(),
l_c.as_base_value().into(),
u_c.as_base_value().into(),
None,
);
let l = l.as_base_value().into();
let u = u.as_base_value().into();
let out_ptr = build_output_struct(ctx, vec![l, u]);
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
}
/// Invokes the `np_linalg_matrix_power` linalg function
@ -2334,124 +2337,156 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(_) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
// Changing second parameter to a `NDArray` for uniformity in function call
let out = numpy::create_ndarray_const_shape(
generator,
ctx,
elem_ty,
&[llvm_usize.const_int(1, false)],
)
.unwrap();
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
let res =
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
Ok(res)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
}
// The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call.
let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1))
.construct_const_shape(generator, ctx, &[1], None);
unsafe { det.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let out_c = det.make_contiguous_ndarray(generator, ctx);
extern_fns::call_np_linalg_det(
ctx,
x1_c.as_base_value().into(),
out_c.as_base_value().into(),
None,
);
// Get the determinant out of `out`
let det = unsafe { det.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
Ok(det)
}
/// Invokes the `sp_linalg_schur` linalg function
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_schur";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
assert_eq!(ndims, 2);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None);
let out_ptr = build_output_struct(ctx, vec![out_t, out_z]);
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
}
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
let t = out_ndarray_ty.construct_uninitialized(
generator,
ctx,
llvm_usize.const_int(2, false),
None,
);
t.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { t.create_data(generator, ctx) };
let z = out_ndarray_ty.construct_uninitialized(
generator,
ctx,
llvm_usize.const_int(2, false),
None,
);
z.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { z.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let t_c = t.make_contiguous_ndarray(generator, ctx);
let z_c = z.make_contiguous_ndarray(generator, ctx);
extern_fns::call_sp_linalg_schur(
ctx,
x1_c.as_base_value().into(),
t_c.as_base_value().into(),
z_c.as_base_value().into(),
None,
);
let t = t.as_base_value().into();
let z = z.as_base_value().into();
let out_ptr = build_output_struct(ctx, vec![t, z]);
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
}
/// Invokes the `sp_linalg_hessenberg` linalg function
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_hessenberg";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
};
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let ndims = extract_ndims(&ctx.unifier, ndims);
assert_eq!(ndims, 2);
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
let dim0 = unsafe {
n1.shape()
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
};
let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None);
let out_ptr = build_output_struct(ctx, vec![out_h, out_q]);
Ok(ctx
.builder
.build_load(out_ptr, "Hessenberg_decomposition_result")
.map(Into::into)
.unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
if !x1.get_type().element_type().is_float_type() {
unsupported_type(ctx, FN_NAME, &[x1_ty]);
}
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
let h = out_ndarray_ty.construct_uninitialized(
generator,
ctx,
llvm_usize.const_int(2, false),
None,
);
h.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { h.create_data(generator, ctx) };
let q = out_ndarray_ty.construct_uninitialized(
generator,
ctx,
llvm_usize.const_int(2, false),
None,
);
q.copy_shape_from_ndarray(generator, ctx, x1);
unsafe { q.create_data(generator, ctx) };
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
let h_c = h.make_contiguous_ndarray(generator, ctx);
let q_c = q.make_contiguous_ndarray(generator, ctx);
extern_fns::call_sp_linalg_hessenberg(
ctx,
x1_c.as_base_value().into(),
h_c.as_base_value().into(),
q_c.as_base_value().into(),
None,
);
let h = h.as_base_value().into();
let q = q.as_base_value().into();
let out_ptr = build_output_struct(ctx, vec![h, q]);
Ok(ctx.builder.build_load(out_ptr, "Hessenberg_decomposition_result").map(Into::into).unwrap())
}

View File

@ -0,0 +1,253 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
AddressSpace,
};
use itertools::Itertools;
use nac3core_derive::StructFields;
use super::ProxyType;
use crate::{
codegen::{
types::structure::{FieldIndexCounter, StructField, StructFields},
values::{ArraySliceValue, ContiguousNDArrayValue, ProxyValue},
CodeGenContext, CodeGenerator,
},
toplevel::numpy::unpack_ndarray_var_tys,
typecheck::typedef::Type,
};
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct ContiguousNDArrayType<'ctx> {
ty: PointerType<'ctx>,
item: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
}
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
pub struct ContiguousNDArrayFields<'ctx> {
#[value_type(usize)]
pub ndims: StructField<'ctx, IntValue<'ctx>>,
#[value_type(usize.ptr_type(AddressSpace::default()))]
pub shape: StructField<'ctx, PointerValue<'ctx>>,
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
pub data: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> ContiguousNDArrayFields<'ctx> {
#[must_use]
pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
let mut counter = FieldIndexCounter::default();
ContiguousNDArrayFields {
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
shape: StructField::create(
&mut counter,
"shape",
llvm_usize.ptr_type(AddressSpace::default()),
),
data: StructField::create(&mut counter, "data", item.ptr_type(AddressSpace::default())),
}
}
}
impl<'ctx> ContiguousNDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_representable(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let ctx = llvm_ty.get_context();
let fields = ContiguousNDArrayFields::new(ctx, llvm_usize);
let llvm_expected_ty = fields.to_vec();
let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
};
if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() {
return Err(format!(
"Expected {} fields in `ContiguousNDArray`, got {}",
llvm_expected_ty.len(),
llvm_ndarray_ty.count_fields()
));
}
llvm_expected_ty
.iter()
.enumerate()
.map(|(i, expected_ty)| {
(expected_ty.0, expected_ty.1, llvm_ndarray_ty.get_field_type_at_index(i as u32).unwrap())
})
.try_for_each(|(field_name, expected_ty, actual_ty)| {
if field_name == fields.data.name() {
if actual_ty.is_pointer_type() {
Ok(())
} else {
Err(format!("Expected T* for `ContiguousNDArray.{field_name}`, got {actual_ty}"))
}
} else if expected_ty == actual_ty {
Ok(())
} else {
Err(format!("Expected {expected_ty} for `ContiguousNDArray.{field_name}`, got {actual_ty}"))
}
})?;
Ok(())
}
/// Returns an instance of [`StructFields`] containing all field accessors for this type.
#[must_use]
fn fields(
item: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
) -> ContiguousNDArrayFields<'ctx> {
ContiguousNDArrayFields::new_typed(item, llvm_usize)
}
/// See [`NDArrayType::fields`].
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(&self) -> ContiguousNDArrayFields<'ctx> {
Self::fields(self.item, self.llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
#[must_use]
fn llvm_type(
ctx: &'ctx Context,
item: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
) -> PointerType<'ctx> {
let field_tys =
Self::fields(item, llvm_usize).into_iter().map(|field| field.1).collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}
/// Creates an instance of [`ContiguousNDArrayType`].
#[must_use]
pub fn new<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
item: BasicTypeEnum<'ctx>,
) -> Self {
let llvm_usize = generator.get_size_type(ctx);
let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize);
Self { ty: llvm_cndarray, item, llvm_usize }
}
/// Creates an [`ContiguousNDArrayType`] from a [unifier type][Type].
#[must_use]
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type,
) -> Self {
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
let llvm_usize = generator.get_size_type(ctx.ctx);
Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize }
}
/// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`.
#[must_use]
pub fn from_type(
ptr_ty: PointerType<'ctx>,
item: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Self {
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
Self { ty: ptr_ty, item, llvm_usize }
}
/// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base type.
#[must_use]
pub fn alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self as ProxyType<'ctx>>::Value {
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
self.raw_alloca(generator, ctx, name),
self.item,
self.llvm_usize,
name,
)
}
}
impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> {
type Base = PointerType<'ctx>;
type Value = ContiguousNDArrayValue<'ctx>;
fn is_type<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: impl BasicType<'ctx>,
) -> Result<(), String> {
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
} else {
Err(format!("Expected pointer type, got {llvm_ty:?}"))
}
}
fn is_representable<G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &'ctx Context,
llvm_ty: Self::Base,
) -> Result<(), String> {
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
}
fn raw_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: Option<&'ctx str>,
) -> <Self::Value as ProxyValue<'ctx>>::Base {
generator
.gen_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
name,
)
.unwrap()
}
fn array_alloca<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
size: IntValue<'ctx>,
name: Option<&'ctx str>,
) -> ArraySliceValue<'ctx> {
generator
.gen_array_var_alloc(
ctx,
self.as_base_type().get_element_type().into_struct_type().into(),
size,
name,
)
.unwrap()
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<ContiguousNDArrayType<'ctx>> for PointerType<'ctx> {
fn from(value: ContiguousNDArrayType<'ctx>) -> Self {
value.as_base_type()
}
}

View File

@ -20,6 +20,9 @@ use crate::{
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
typecheck::typedef::Type,
};
pub use contiguous::*;
mod contiguous;
/// Proxy type for a `ndarray` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]

View File

@ -103,6 +103,12 @@ where
StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
}
/// Returns the name of this field.
#[must_use]
pub fn name(&self) -> &'static str {
self.name
}
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
/// {idx...}, i32 {self.index}`.
pub fn ptr_by_array_gep(

View File

@ -0,0 +1,206 @@
use inkwell::{
types::{BasicType, BasicTypeEnum, IntType},
values::{IntValue, PointerValue},
AddressSpace,
};
use super::{ArrayLikeValue, NDArrayValue, ProxyValue};
use crate::codegen::{
stmt::gen_if_callback,
types::{structure::StructField, ContiguousNDArrayType, NDArrayType},
CodeGenContext, CodeGenerator,
};
#[derive(Copy, Clone)]
pub struct ContiguousNDArrayValue<'ctx> {
value: PointerValue<'ctx>,
item: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> ContiguousNDArrayValue<'ctx> {
/// Checks whether `value` is an instance of `ContiguousNDArray`, returning [Err] if `value` is
/// not an instance.
pub fn is_representable(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
<Self as ProxyValue<'ctx>>::Type::is_representable(value.get_type(), llvm_usize)
}
/// Creates an [`ContiguousNDArrayValue`] from a [`PointerValue`].
#[must_use]
pub fn from_pointer_value(
ptr: PointerValue<'ctx>,
dtype: BasicTypeEnum<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
Self { value: ptr, item: dtype, llvm_usize, name }
}
fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
self.get_type().get_fields().ndims
}
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
self.ndims_field().set(ctx, self.as_base_value(), value, self.name);
}
fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields().shape
}
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
self.shape_field().set(ctx, self.as_base_value(), value, self.name);
}
pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.shape_field().get(ctx, self.value, self.name)
}
fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
self.get_type().get_fields().data
}
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
self.data_field().set(ctx, self.as_base_value(), value, self.name);
}
pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.data_field().get(ctx, self.value, self.name)
}
}
impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = ContiguousNDArrayType<'ctx>;
fn get_type(&self) -> Self::Type {
<Self as ProxyValue<'ctx>>::Type::from_type(
self.as_base_value().get_type(),
self.item,
self.llvm_usize,
)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<ContiguousNDArrayValue<'ctx>> for PointerValue<'ctx> {
fn from(value: ContiguousNDArrayValue<'ctx>) -> Self {
value.as_base_value()
}
}
impl<'ctx> NDArrayValue<'ctx> {
/// Create a [`ContiguousNDArray`] from the contents of this ndarray.
///
/// This function may or may not be expensive depending on if this ndarray has contiguous data.
///
/// If this ndarray is not C-contiguous, this function will allocate memory on the stack for the `data` field of
/// the returned [`ContiguousNDArray`] and copy contents of this ndarray to there.
///
/// If this ndarray is C-contiguous, contents of this ndarray will not be copied. The created [`ContiguousNDArray`]
/// will share memory with this ndarray.
///
/// The `item_model` sets the [`Model`] of the returned [`ContiguousNDArray`]'s `Item` model for type-safety, and
/// should match the `ctx.get_llvm_type()` of this ndarray's `dtype`. Otherwise this function panics. Use model [`Any`]
/// if you don't care/cannot know the [`Model`] in advance.
pub fn make_contiguous_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ContiguousNDArrayValue<'ctx> {
let result = ContiguousNDArrayType::new(generator, ctx.ctx, self.dtype)
.alloca(generator, ctx, self.name);
// Set ndims and shape.
let ndims = self
.ndims
.map_or_else(|| self.load_ndims(ctx), |ndims| self.llvm_usize.const_int(ndims, false));
result.store_ndims(ctx, ndims);
let shape = self.shape();
result.store_shape(ctx, shape.base_ptr(ctx, generator));
gen_if_callback(
generator,
ctx,
|generator, ctx| Ok(self.is_c_contiguous(generator, ctx)),
|_, ctx| {
// This ndarray is contiguous.
let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name);
let data = ctx
.builder
.build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "")
.unwrap();
result.store_data(ctx, data);
Ok(())
},
|generator, ctx| {
// This ndarray is not contiguous. Do a full-copy on `data`. `make_copy` produces an
// ndarray with contiguous `data`.
let copied_ndarray = self.make_copy(generator, ctx);
let data = copied_ndarray.data().base_ptr(ctx, generator);
let data = ctx
.builder
.build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "")
.unwrap();
result.store_data(ctx, data);
Ok(())
},
)
.unwrap();
result
}
/// Create an [`NDArrayObject`] from a [`ContiguousNDArray`].
///
/// The operation is super cheap. The newly created [`NDArrayObject`] will share the
/// same memory as the [`ContiguousNDArray`].
///
/// `ndims` has to be provided as [`NDArrayObject`] requires a statically known `ndims` value, despite
/// the fact that the information should be contained within the [`ContiguousNDArray`].
pub fn from_contiguous_ndarray<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
carray: ContiguousNDArrayValue<'ctx>,
ndims: u64,
) -> Self {
// TODO: Debug assert `ndims == carray.ndims` to catch bugs.
// Allocate the resulting ndarray.
let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, Some(ndims))
.construct_uninitialized(
generator,
ctx,
carray.llvm_usize.const_int(ndims, false),
carray.name,
);
// Copy shape and update strides
let shape = carray.load_shape(ctx);
ndarray.copy_shape_from_array(generator, ctx, shape);
ndarray.set_strides_contiguous(generator, ctx);
// Share data
let data = carray.load_data(ctx);
ndarray.store_data(
ctx,
ctx.builder
.build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
.unwrap(),
);
ndarray
}
}

View File

@ -16,6 +16,9 @@ use crate::codegen::{
types::{structure::StructField, NDArrayType},
CodeGenContext, CodeGenerator,
};
pub use contiguous::*;
mod contiguous;
/// Proxy type for accessing an `NDArray` value in LLVM.
#[derive(Copy, Clone)]
@ -362,6 +365,29 @@ impl<'ctx> NDArrayValue<'ctx> {
irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self);
}
#[must_use]
pub fn make_copy<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
let clone = self.get_type().construct_uninitialized(
generator,
ctx,
self.ndims.map_or_else(
|| self.load_ndims(ctx),
|ndims| self.llvm_usize.const_int(ndims, false),
),
None,
);
let shape = self.shape();
clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
unsafe { clone.create_data(generator, ctx) };
clone.copy_data_from(generator, ctx, *self);
clone
}
/// Copy data from another ndarray.
///
/// This ndarray and `src` is that their `np.size()` should be the same. Their shapes

View File

@ -1758,16 +1758,14 @@ def run() -> int32:
test_ndarray_transpose()
test_ndarray_reshape()
test_ndarray_dot()
# test_ndarray_cholesky()
# test_ndarray_qr()
# test_ndarray_svd()
# test_ndarray_linalg_inv()
# test_ndarray_pinv()
test_ndarray_cholesky()
test_ndarray_qr()
test_ndarray_svd()
test_ndarray_linalg_inv()
test_ndarray_pinv()
# test_ndarray_matrix_power()
# test_ndarray_det()
# test_ndarray_lu()
# test_ndarray_schur()
# test_ndarray_hessenberg()
test_ndarray_det()
test_ndarray_lu()
test_ndarray_schur()
test_ndarray_hessenberg()
return 0