forked from M-Labs/nac3
[core] codegen: Implement ContiguousNDArray
Fixes compatibility with linalg algorithms. matrix_power is missing due to the need for indexing support.
This commit is contained in:
parent
57da0f67d1
commit
e965a7c7ce
@ -14,6 +14,7 @@ use super::{
|
|||||||
numpy,
|
numpy,
|
||||||
numpy::ndarray_elementwise_unaryop_impl,
|
numpy::ndarray_elementwise_unaryop_impl,
|
||||||
stmt::gen_for_callback_incrementing,
|
stmt::gen_for_callback_incrementing,
|
||||||
|
types::NDArrayType,
|
||||||
values::{
|
values::{
|
||||||
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
||||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
@ -1982,288 +1983,290 @@ fn build_output_struct<'ctx>(
|
|||||||
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "np_linalg_cholesky";
|
const FN_NAME: &str = "np_linalg_cholesky";
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
};
|
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
let dim0 = unsafe {
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
|
|
||||||
Ok(out)
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
|
||||||
|
.construct_uninitialized(generator, ctx, llvm_usize.const_int(2, false), None);
|
||||||
|
out.copy_shape_from_ndarray(generator, ctx, x1);
|
||||||
|
unsafe { out.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let out_c = out.make_contiguous_ndarray(generator, ctx);
|
||||||
|
extern_fns::call_np_linalg_cholesky(
|
||||||
|
ctx,
|
||||||
|
x1_c.as_base_value().into(),
|
||||||
|
out_c.as_base_value().into(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
Ok(out.as_base_value().into())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `np_linalg_qr` linalg function
|
/// Invokes the `np_linalg_qr` linalg function
|
||||||
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "np_linalg_qr";
|
const FN_NAME: &str = "np_linalg_qr";
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
};
|
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
let dim0 = unsafe {
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
|
||||||
|
|
||||||
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
|
|
||||||
|
|
||||||
let out_ptr = build_output_struct(ctx, vec![out_q, out_r]);
|
|
||||||
|
|
||||||
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let x1_shape = x1.shape();
|
||||||
|
let d0 =
|
||||||
|
unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||||
|
let d1 = unsafe {
|
||||||
|
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
};
|
||||||
|
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
|
||||||
|
|
||||||
|
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
|
||||||
|
let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None);
|
||||||
|
unsafe { q.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let r = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[dk, d1], None);
|
||||||
|
unsafe { r.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let q_c = q.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let r_c = r.make_contiguous_ndarray(generator, ctx);
|
||||||
|
|
||||||
|
extern_fns::call_np_linalg_qr(
|
||||||
|
ctx,
|
||||||
|
x1_c.as_base_value().into(),
|
||||||
|
q_c.as_base_value().into(),
|
||||||
|
r_c.as_base_value().into(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
let q = q.as_base_value().into();
|
||||||
|
let r = r.as_base_value().into();
|
||||||
|
let out_ptr = build_output_struct(ctx, vec![q, r]);
|
||||||
|
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `np_linalg_svd` linalg function
|
/// Invokes the `np_linalg_svd` linalg function
|
||||||
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "np_linalg_svd";
|
const FN_NAME: &str = "np_linalg_svd";
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
};
|
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
|
||||||
|
|
||||||
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None);
|
|
||||||
|
|
||||||
let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]);
|
|
||||||
|
|
||||||
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let x1_shape = x1.shape();
|
||||||
|
let d0 =
|
||||||
|
unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||||
|
let d1 = unsafe {
|
||||||
|
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
};
|
||||||
|
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
|
||||||
|
|
||||||
|
let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1));
|
||||||
|
let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
|
||||||
|
|
||||||
|
let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None);
|
||||||
|
unsafe { u.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let s = out_ndarray1_ty.construct_dyn_shape(generator, ctx, &[dk], None);
|
||||||
|
unsafe { s.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let vh = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d1, d1], None);
|
||||||
|
unsafe { vh.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let u_c = u.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let s_c = s.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let vh_c = vh.make_contiguous_ndarray(generator, ctx);
|
||||||
|
|
||||||
|
extern_fns::call_np_linalg_svd(
|
||||||
|
ctx,
|
||||||
|
x1_c.as_base_value().into(),
|
||||||
|
u_c.as_base_value().into(),
|
||||||
|
s_c.as_base_value().into(),
|
||||||
|
vh_c.as_base_value().into(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
let u = u.as_base_value().into();
|
||||||
|
let s = s.as_base_value().into();
|
||||||
|
let vh = vh.as_base_value().into();
|
||||||
|
let out_ptr = build_output_struct(ctx, vec![u, s, vh]);
|
||||||
|
|
||||||
|
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `np_linalg_inv` linalg function
|
/// Invokes the `np_linalg_inv` linalg function
|
||||||
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "np_linalg_inv";
|
const FN_NAME: &str = "np_linalg_inv";
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
};
|
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
let dim0 = unsafe {
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_inv(ctx, x1, out, None);
|
|
||||||
Ok(out)
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
|
||||||
|
.construct_uninitialized(generator, ctx, llvm_usize.const_int(2, false), None);
|
||||||
|
out.copy_shape_from_ndarray(generator, ctx, x1);
|
||||||
|
unsafe { out.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let out_c = out.make_contiguous_ndarray(generator, ctx);
|
||||||
|
extern_fns::call_np_linalg_inv(
|
||||||
|
ctx,
|
||||||
|
x1_c.as_base_value().into(),
|
||||||
|
out_c.as_base_value().into(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(out.as_base_value().into())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `np_linalg_pinv` linalg function
|
/// Invokes the `np_linalg_pinv` linalg function
|
||||||
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "np_linalg_pinv";
|
const FN_NAME: &str = "np_linalg_pinv";
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
};
|
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
|
|
||||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
|
|
||||||
extern_fns::call_np_linalg_pinv(ctx, x1, out, None);
|
|
||||||
Ok(out)
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let x1_shape = x1.shape();
|
||||||
|
let d0 =
|
||||||
|
unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||||
|
let d1 = unsafe {
|
||||||
|
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
};
|
||||||
|
|
||||||
|
let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2))
|
||||||
|
.construct_dyn_shape(generator, ctx, &[d1, d0], None); // pinv of a (d0, d1) matrix has shape (d1, d0)
|
||||||
|
unsafe { out.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let out_c = out.make_contiguous_ndarray(generator, ctx);
|
||||||
|
extern_fns::call_np_linalg_pinv(
|
||||||
|
ctx,
|
||||||
|
x1_c.as_base_value().into(),
|
||||||
|
out_c.as_base_value().into(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(out.as_base_value().into())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `sp_linalg_lu` linalg function
|
/// Invokes the `sp_linalg_lu` linalg function
|
||||||
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "sp_linalg_lu";
|
const FN_NAME: &str = "sp_linalg_lu";
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
};
|
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let dim1 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
|
||||||
|
|
||||||
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
|
|
||||||
extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None);
|
|
||||||
|
|
||||||
let out_ptr = build_output_struct(ctx, vec![out_l, out_u]);
|
|
||||||
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let x1_shape = x1.shape();
|
||||||
|
let d0 =
|
||||||
|
unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||||
|
let d1 = unsafe {
|
||||||
|
x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||||
|
};
|
||||||
|
let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None);
|
||||||
|
|
||||||
|
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
|
||||||
|
|
||||||
|
let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None);
|
||||||
|
unsafe { l.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let u = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[dk, d1], None);
|
||||||
|
unsafe { u.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let l_c = l.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let u_c = u.make_contiguous_ndarray(generator, ctx);
|
||||||
|
extern_fns::call_sp_linalg_lu(
|
||||||
|
ctx,
|
||||||
|
x1_c.as_base_value().into(),
|
||||||
|
l_c.as_base_value().into(),
|
||||||
|
u_c.as_base_value().into(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
let l = l.as_base_value().into();
|
||||||
|
let u = u.as_base_value().into();
|
||||||
|
let out_ptr = build_output_struct(ctx, vec![l, u]);
|
||||||
|
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `np_linalg_matrix_power` linalg function
|
/// Invokes the `np_linalg_matrix_power` linalg function
|
||||||
@ -2334,124 +2337,156 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
const FN_NAME: &str = "np_linalg_det";
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
if let BasicValueEnum::PointerValue(_) = x1 {
|
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
let out = numpy::create_ndarray_const_shape(
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
generator,
|
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
ctx,
|
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
|
||||||
elem_ty,
|
|
||||||
&[llvm_usize.const_int(1, false)],
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
)
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
.unwrap();
|
|
||||||
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
|
|
||||||
let res =
|
|
||||||
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
|
||||||
Ok(res)
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call.
|
||||||
|
let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1))
|
||||||
|
.construct_const_shape(generator, ctx, &[1], None);
|
||||||
|
unsafe { det.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let out_c = det.make_contiguous_ndarray(generator, ctx);
|
||||||
|
extern_fns::call_np_linalg_det(
|
||||||
|
ctx,
|
||||||
|
x1_c.as_base_value().into(),
|
||||||
|
out_c.as_base_value().into(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Read the scalar determinant back out of the single-element `det` ndarray
|
||||||
|
let det = unsafe { det.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||||
|
Ok(det)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `sp_linalg_schur` linalg function
|
/// Invokes the `sp_linalg_schur` linalg function
|
||||||
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "sp_linalg_schur";
|
const FN_NAME: &str = "sp_linalg_schur";
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
};
|
assert_eq!(ndims, 2);
|
||||||
|
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
|
|
||||||
extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None);
|
|
||||||
|
|
||||||
let out_ptr = build_output_struct(ctx, vec![out_t, out_z]);
|
|
||||||
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
|
||||||
|
|
||||||
|
let t = out_ndarray_ty.construct_uninitialized(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_usize.const_int(2, false),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
t.copy_shape_from_ndarray(generator, ctx, x1);
|
||||||
|
unsafe { t.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let z = out_ndarray_ty.construct_uninitialized(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_usize.const_int(2, false),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
z.copy_shape_from_ndarray(generator, ctx, x1);
|
||||||
|
unsafe { z.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let t_c = t.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let z_c = z.make_contiguous_ndarray(generator, ctx);
|
||||||
|
extern_fns::call_sp_linalg_schur(
|
||||||
|
ctx,
|
||||||
|
x1_c.as_base_value().into(),
|
||||||
|
t_c.as_base_value().into(),
|
||||||
|
z_c.as_base_value().into(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
let t = t.as_base_value().into();
|
||||||
|
let z = z.as_base_value().into();
|
||||||
|
let out_ptr = build_output_struct(ctx, vec![t, z]);
|
||||||
|
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `sp_linalg_hessenberg` linalg function
|
/// Invokes the `sp_linalg_hessenberg` linalg function
|
||||||
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "sp_linalg_hessenberg";
|
const FN_NAME: &str = "sp_linalg_hessenberg";
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) };
|
||||||
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
|
||||||
|
|
||||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
let (elem_ty, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
};
|
assert_eq!(ndims, 2);
|
||||||
|
let x1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||||
|
let x1 = NDArrayValue::from_pointer_value(x1, x1_elem_ty, Some(ndims), llvm_usize, None);
|
||||||
|
|
||||||
let n1 = NDArrayValue::from_pointer_value(n1, n1_elem_ty, Some(ndims), llvm_usize, None);
|
if !x1.get_type().element_type().is_float_type() {
|
||||||
|
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||||
let dim0 = unsafe {
|
|
||||||
n1.shape()
|
|
||||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|
||||||
.into_int_value()
|
|
||||||
};
|
|
||||||
let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
|
||||||
.unwrap()
|
|
||||||
.as_base_value()
|
|
||||||
.as_basic_value_enum();
|
|
||||||
extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None);
|
|
||||||
|
|
||||||
let out_ptr = build_output_struct(ctx, vec![out_h, out_q]);
|
|
||||||
Ok(ctx
|
|
||||||
.builder
|
|
||||||
.build_load(out_ptr, "Hessenberg_decomposition_result")
|
|
||||||
.map(Into::into)
|
|
||||||
.unwrap())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2));
|
||||||
|
|
||||||
|
let h = out_ndarray_ty.construct_uninitialized(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_usize.const_int(2, false),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
h.copy_shape_from_ndarray(generator, ctx, x1);
|
||||||
|
unsafe { h.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let q = out_ndarray_ty.construct_uninitialized(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
llvm_usize.const_int(2, false),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
q.copy_shape_from_ndarray(generator, ctx, x1);
|
||||||
|
unsafe { q.create_data(generator, ctx) };
|
||||||
|
|
||||||
|
let x1_c = x1.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let h_c = h.make_contiguous_ndarray(generator, ctx);
|
||||||
|
let q_c = q.make_contiguous_ndarray(generator, ctx);
|
||||||
|
extern_fns::call_sp_linalg_hessenberg(
|
||||||
|
ctx,
|
||||||
|
x1_c.as_base_value().into(),
|
||||||
|
h_c.as_base_value().into(),
|
||||||
|
q_c.as_base_value().into(),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
let h = h.as_base_value().into();
|
||||||
|
let q = q.as_base_value().into();
|
||||||
|
let out_ptr = build_output_struct(ctx, vec![h, q]);
|
||||||
|
Ok(ctx.builder.build_load(out_ptr, "Hessenberg_decomposition_result").map(Into::into).unwrap())
|
||||||
}
|
}
|
||||||
|
253
nac3core/src/codegen/types/ndarray/contiguous.rs
Normal file
253
nac3core/src/codegen/types/ndarray/contiguous.rs
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
use inkwell::{
|
||||||
|
context::Context,
|
||||||
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||||
|
values::{IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
|
|
||||||
|
use nac3core_derive::StructFields;
|
||||||
|
|
||||||
|
use super::ProxyType;
|
||||||
|
use crate::{
|
||||||
|
codegen::{
|
||||||
|
types::structure::{FieldIndexCounter, StructField, StructFields},
|
||||||
|
values::{ArraySliceValue, ContiguousNDArrayValue, ProxyValue},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
toplevel::numpy::unpack_ndarray_var_tys,
|
||||||
|
typecheck::typedef::Type,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
pub struct ContiguousNDArrayType<'ctx> {
|
||||||
|
ty: PointerType<'ctx>,
|
||||||
|
item: BasicTypeEnum<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Clone, Copy, StructFields)]
|
||||||
|
pub struct ContiguousNDArrayFields<'ctx> {
|
||||||
|
#[value_type(usize)]
|
||||||
|
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||||
|
#[value_type(usize.ptr_type(AddressSpace::default()))]
|
||||||
|
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
#[value_type(i8_type().ptr_type(AddressSpace::default()))]
|
||||||
|
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ContiguousNDArrayFields<'ctx> {
|
||||||
|
#[must_use]
|
||||||
|
pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
|
||||||
|
let mut counter = FieldIndexCounter::default();
|
||||||
|
|
||||||
|
ContiguousNDArrayFields {
|
||||||
|
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
|
||||||
|
shape: StructField::create(
|
||||||
|
&mut counter,
|
||||||
|
"shape",
|
||||||
|
llvm_usize.ptr_type(AddressSpace::default()),
|
||||||
|
),
|
||||||
|
data: StructField::create(&mut counter, "data", item.ptr_type(AddressSpace::default())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ContiguousNDArrayType<'ctx> {
|
||||||
|
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
|
||||||
|
pub fn is_representable(
|
||||||
|
llvm_ty: PointerType<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
let ctx = llvm_ty.get_context();
|
||||||
|
|
||||||
|
let fields = ContiguousNDArrayFields::new(ctx, llvm_usize);
|
||||||
|
let llvm_expected_ty = fields.to_vec();
|
||||||
|
|
||||||
|
let llvm_ndarray_ty = llvm_ty.get_element_type();
|
||||||
|
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||||
|
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
|
||||||
|
};
|
||||||
|
if llvm_ndarray_ty.count_fields() != u32::try_from(llvm_expected_ty.len()).unwrap() {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected {} fields in `ContiguousNDArray`, got {}",
|
||||||
|
llvm_expected_ty.len(),
|
||||||
|
llvm_ndarray_ty.count_fields()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm_expected_ty
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, expected_ty)| {
|
||||||
|
(expected_ty.0, expected_ty.1, llvm_ndarray_ty.get_field_type_at_index(i as u32).unwrap())
|
||||||
|
})
|
||||||
|
.try_for_each(|(field_name, expected_ty, actual_ty)| {
|
||||||
|
if field_name == fields.data.name() {
|
||||||
|
if actual_ty.is_pointer_type() {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(format!("Expected T* for `ContiguousNDArray.{field_name}`, got {actual_ty}"))
|
||||||
|
}
|
||||||
|
} else if expected_ty == actual_ty {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(format!("Expected {expected_ty} for `ContiguousNDArray.{field_name}`, got {actual_ty}"))
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an instance of [`StructFields`] containing all field accessors for this type.
|
||||||
|
#[must_use]
|
||||||
|
fn fields(
|
||||||
|
item: BasicTypeEnum<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> ContiguousNDArrayFields<'ctx> {
|
||||||
|
ContiguousNDArrayFields::new_typed(item, llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// See [`NDArrayType::fields`].
|
||||||
|
// TODO: Move this into e.g. StructProxyType
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_fields(&self) -> ContiguousNDArrayFields<'ctx> {
|
||||||
|
Self::fields(self.item, self.llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||||
|
#[must_use]
|
||||||
|
fn llvm_type(
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
item: BasicTypeEnum<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> PointerType<'ctx> {
|
||||||
|
let field_tys =
|
||||||
|
Self::fields(item, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||||
|
|
||||||
|
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an instance of [`ContiguousNDArrayType`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn new<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
item: BasicTypeEnum<'ctx>,
|
||||||
|
) -> Self {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx);
|
||||||
|
let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize);
|
||||||
|
|
||||||
|
Self { ty: llvm_cndarray, item, llvm_usize }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an [`ContiguousNDArrayType`] from a [unifier type][Type].
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_unifier_type<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ty: Type,
|
||||||
|
) -> Self {
|
||||||
|
let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||||
|
|
||||||
|
let llvm_dtype = ctx.get_llvm_type(generator, dtype);
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`.
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_type(
|
||||||
|
ptr_ty: PointerType<'ctx>,
|
||||||
|
item: BasicTypeEnum<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> Self {
|
||||||
|
debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok());
|
||||||
|
|
||||||
|
Self { ty: ptr_ty, item, llvm_usize }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base type.
|
||||||
|
#[must_use]
|
||||||
|
pub fn alloca<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self as ProxyType<'ctx>>::Value {
|
||||||
|
<Self as ProxyType<'ctx>>::Value::from_pointer_value(
|
||||||
|
self.raw_alloca(generator, ctx, name),
|
||||||
|
self.item,
|
||||||
|
self.llvm_usize,
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> {
|
||||||
|
type Base = PointerType<'ctx>;
|
||||||
|
type Value = ContiguousNDArrayValue<'ctx>;
|
||||||
|
|
||||||
|
fn is_type<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_ty: impl BasicType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() {
|
||||||
|
<Self as ProxyType<'ctx>>::is_representable(generator, ctx, ty)
|
||||||
|
} else {
|
||||||
|
Err(format!("Expected pointer type, got {llvm_ty:?}"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_representable<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &G,
|
||||||
|
ctx: &'ctx Context,
|
||||||
|
llvm_ty: Self::Base,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
Self::is_representable(llvm_ty, generator.get_size_type(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn raw_alloca<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> <Self::Value as ProxyValue<'ctx>>::Base {
|
||||||
|
generator
|
||||||
|
.gen_var_alloc(
|
||||||
|
ctx,
|
||||||
|
self.as_base_type().get_element_type().into_struct_type().into(),
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn array_alloca<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
size: IntValue<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> ArraySliceValue<'ctx> {
|
||||||
|
generator
|
||||||
|
.gen_array_var_alloc(
|
||||||
|
ctx,
|
||||||
|
self.as_base_type().get_element_type().into_struct_type().into(),
|
||||||
|
size,
|
||||||
|
name,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_base_type(&self) -> Self::Base {
|
||||||
|
self.ty
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> From<ContiguousNDArrayType<'ctx>> for PointerType<'ctx> {
    /// Unwraps the proxy into its underlying LLVM pointer type.
    fn from(proxy_ty: ContiguousNDArrayType<'ctx>) -> Self {
        proxy_ty.as_base_type()
    }
}
|
@ -20,6 +20,9 @@ use crate::{
|
|||||||
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
||||||
typecheck::typedef::Type,
|
typecheck::typedef::Type,
|
||||||
};
|
};
|
||||||
|
pub use contiguous::*;
|
||||||
|
|
||||||
|
mod contiguous;
|
||||||
|
|
||||||
/// Proxy type for a `ndarray` type in LLVM.
|
/// Proxy type for a `ndarray` type in LLVM.
|
||||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||||
|
@ -103,6 +103,12 @@ where
|
|||||||
StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
|
StructField { index, name, ty: ty.into(), _value_ty: PhantomData }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the name of this field.
|
||||||
|
#[must_use]
|
||||||
|
pub fn name(&self) -> &'static str {
|
||||||
|
self.name
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
|
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
|
||||||
/// {idx...}, i32 {self.index}`.
|
/// {idx...}, i32 {self.index}`.
|
||||||
pub fn ptr_by_array_gep(
|
pub fn ptr_by_array_gep(
|
||||||
|
206
nac3core/src/codegen/values/ndarray/contiguous.rs
Normal file
206
nac3core/src/codegen/values/ndarray/contiguous.rs
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
use inkwell::{
|
||||||
|
types::{BasicType, BasicTypeEnum, IntType},
|
||||||
|
values::{IntValue, PointerValue},
|
||||||
|
AddressSpace,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{ArrayLikeValue, NDArrayValue, ProxyValue};
|
||||||
|
use crate::codegen::{
|
||||||
|
stmt::gen_if_callback,
|
||||||
|
types::{structure::StructField, ContiguousNDArrayType, NDArrayType},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct ContiguousNDArrayValue<'ctx> {
|
||||||
|
value: PointerValue<'ctx>,
|
||||||
|
item: BasicTypeEnum<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ContiguousNDArrayValue<'ctx> {
|
||||||
|
/// Checks whether `value` is an instance of `ContiguousNDArray`, returning [Err] if `value` is
|
||||||
|
/// not an instance.
|
||||||
|
pub fn is_representable(
|
||||||
|
value: PointerValue<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
) -> Result<(), String> {
|
||||||
|
<Self as ProxyValue<'ctx>>::Type::is_representable(value.get_type(), llvm_usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an [`ContiguousNDArrayValue`] from a [`PointerValue`].
|
||||||
|
#[must_use]
|
||||||
|
pub fn from_pointer_value(
|
||||||
|
ptr: PointerValue<'ctx>,
|
||||||
|
dtype: BasicTypeEnum<'ctx>,
|
||||||
|
llvm_usize: IntType<'ctx>,
|
||||||
|
name: Option<&'ctx str>,
|
||||||
|
) -> Self {
|
||||||
|
debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok());
|
||||||
|
|
||||||
|
Self { value: ptr, item: dtype, llvm_usize, name }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> {
|
||||||
|
self.get_type().get_fields().ndims
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) {
|
||||||
|
self.ndims_field().set(ctx, self.as_base_value(), value, self.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||||
|
self.get_type().get_fields().shape
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
|
||||||
|
self.shape_field().set(ctx, self.as_base_value(), value, self.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
self.shape_field().get(ctx, self.value, self.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> {
|
||||||
|
self.get_type().get_fields().data
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
|
||||||
|
self.data_field().set(ctx, self.as_base_value(), value, self.name);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
self.data_field().get(ctx, self.value, self.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> {
|
||||||
|
type Base = PointerValue<'ctx>;
|
||||||
|
type Type = ContiguousNDArrayType<'ctx>;
|
||||||
|
|
||||||
|
fn get_type(&self) -> Self::Type {
|
||||||
|
<Self as ProxyValue<'ctx>>::Type::from_type(
|
||||||
|
self.as_base_value().get_type(),
|
||||||
|
self.item,
|
||||||
|
self.llvm_usize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_base_value(&self) -> Self::Base {
|
||||||
|
self.value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> From<ContiguousNDArrayValue<'ctx>> for PointerValue<'ctx> {
    /// Unwraps the proxy into its underlying LLVM pointer value.
    fn from(proxy: ContiguousNDArrayValue<'ctx>) -> Self {
        proxy.as_base_value()
    }
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
/// Create a [`ContiguousNDArray`] from the contents of this ndarray.
|
||||||
|
///
|
||||||
|
/// This function may or may not be expensive depending on if this ndarray has contiguous data.
|
||||||
|
///
|
||||||
|
/// If this ndarray is not C-contiguous, this function will allocate memory on the stack for the `data` field of
|
||||||
|
/// the returned [`ContiguousNDArray`] and copy contents of this ndarray to there.
|
||||||
|
///
|
||||||
|
/// If this ndarray is C-contiguous, contents of this ndarray will not be copied. The created [`ContiguousNDArray`]
|
||||||
|
/// will share memory with this ndarray.
|
||||||
|
///
|
||||||
|
/// The `item_model` sets the [`Model`] of the returned [`ContiguousNDArray`]'s `Item` model for type-safety, and
|
||||||
|
/// should match the `ctx.get_llvm_type()` of this ndarray's `dtype`. Otherwise this function panics. Use model [`Any`]
|
||||||
|
/// if you don't care/cannot know the [`Model`] in advance.
|
||||||
|
pub fn make_contiguous_ndarray<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> ContiguousNDArrayValue<'ctx> {
|
||||||
|
let result = ContiguousNDArrayType::new(generator, ctx.ctx, self.dtype)
|
||||||
|
.alloca(generator, ctx, self.name);
|
||||||
|
|
||||||
|
// Set ndims and shape.
|
||||||
|
let ndims = self
|
||||||
|
.ndims
|
||||||
|
.map_or_else(|| self.load_ndims(ctx), |ndims| self.llvm_usize.const_int(ndims, false));
|
||||||
|
result.store_ndims(ctx, ndims);
|
||||||
|
|
||||||
|
let shape = self.shape();
|
||||||
|
result.store_shape(ctx, shape.base_ptr(ctx, generator));
|
||||||
|
|
||||||
|
gen_if_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|generator, ctx| Ok(self.is_c_contiguous(generator, ctx)),
|
||||||
|
|_, ctx| {
|
||||||
|
// This ndarray is contiguous.
|
||||||
|
let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name);
|
||||||
|
let data = ctx
|
||||||
|
.builder
|
||||||
|
.build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "")
|
||||||
|
.unwrap();
|
||||||
|
result.store_data(ctx, data);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
|generator, ctx| {
|
||||||
|
// This ndarray is not contiguous. Do a full-copy on `data`. `make_copy` produces an
|
||||||
|
// ndarray with contiguous `data`.
|
||||||
|
let copied_ndarray = self.make_copy(generator, ctx);
|
||||||
|
let data = copied_ndarray.data().base_ptr(ctx, generator);
|
||||||
|
let data = ctx
|
||||||
|
.builder
|
||||||
|
.build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "")
|
||||||
|
.unwrap();
|
||||||
|
result.store_data(ctx, data);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an [`NDArrayObject`] from a [`ContiguousNDArray`].
|
||||||
|
///
|
||||||
|
/// The operation is super cheap. The newly created [`NDArrayObject`] will share the
|
||||||
|
/// same memory as the [`ContiguousNDArray`].
|
||||||
|
///
|
||||||
|
/// `ndims` has to be provided as [`NDArrayObject`] requires a statically known `ndims` value, despite
|
||||||
|
/// the fact that the information should be contained within the [`ContiguousNDArray`].
|
||||||
|
pub fn from_contiguous_ndarray<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
carray: ContiguousNDArrayValue<'ctx>,
|
||||||
|
ndims: u64,
|
||||||
|
) -> Self {
|
||||||
|
// TODO: Debug assert `ndims == carray.ndims` to catch bugs.
|
||||||
|
|
||||||
|
// Allocate the resulting ndarray.
|
||||||
|
let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, Some(ndims))
|
||||||
|
.construct_uninitialized(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
carray.llvm_usize.const_int(ndims, false),
|
||||||
|
carray.name,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Copy shape and update strides
|
||||||
|
let shape = carray.load_shape(ctx);
|
||||||
|
ndarray.copy_shape_from_array(generator, ctx, shape);
|
||||||
|
ndarray.set_strides_contiguous(generator, ctx);
|
||||||
|
|
||||||
|
// Share data
|
||||||
|
let data = carray.load_data(ctx);
|
||||||
|
ndarray.store_data(
|
||||||
|
ctx,
|
||||||
|
ctx.builder
|
||||||
|
.build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "")
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
|
}
|
@ -16,6 +16,9 @@ use crate::codegen::{
|
|||||||
types::{structure::StructField, NDArrayType},
|
types::{structure::StructField, NDArrayType},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
pub use contiguous::*;
|
||||||
|
|
||||||
|
mod contiguous;
|
||||||
|
|
||||||
/// Proxy type for accessing an `NDArray` value in LLVM.
|
/// Proxy type for accessing an `NDArray` value in LLVM.
|
||||||
#[derive(Copy, Clone)]
|
#[derive(Copy, Clone)]
|
||||||
@ -362,6 +365,29 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||||||
irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self);
|
irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn make_copy<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Self {
|
||||||
|
let clone = self.get_type().construct_uninitialized(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
self.ndims.map_or_else(
|
||||||
|
|| self.load_ndims(ctx),
|
||||||
|
|ndims| self.llvm_usize.const_int(ndims, false),
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
|
let shape = self.shape();
|
||||||
|
clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator));
|
||||||
|
unsafe { clone.create_data(generator, ctx) };
|
||||||
|
clone.copy_data_from(generator, ctx, *self);
|
||||||
|
clone
|
||||||
|
}
|
||||||
|
|
||||||
/// Copy data from another ndarray.
|
/// Copy data from another ndarray.
|
||||||
///
|
///
|
||||||
/// This ndarray and `src` is that their `np.size()` should be the same. Their shapes
|
/// This ndarray and `src` is that their `np.size()` should be the same. Their shapes
|
||||||
|
@ -1758,16 +1758,14 @@ def run() -> int32:
|
|||||||
test_ndarray_transpose()
|
test_ndarray_transpose()
|
||||||
test_ndarray_reshape()
|
test_ndarray_reshape()
|
||||||
|
|
||||||
test_ndarray_dot()
|
test_ndarray_cholesky()
|
||||||
|
test_ndarray_qr()
|
||||||
# test_ndarray_cholesky()
|
test_ndarray_svd()
|
||||||
# test_ndarray_qr()
|
test_ndarray_linalg_inv()
|
||||||
# test_ndarray_svd()
|
test_ndarray_pinv()
|
||||||
# test_ndarray_linalg_inv()
|
|
||||||
# test_ndarray_pinv()
|
|
||||||
# test_ndarray_matrix_power()
|
# test_ndarray_matrix_power()
|
||||||
# test_ndarray_det()
|
test_ndarray_det()
|
||||||
# test_ndarray_lu()
|
test_ndarray_lu()
|
||||||
# test_ndarray_schur()
|
test_ndarray_schur()
|
||||||
# test_ndarray_hessenberg()
|
test_ndarray_hessenberg()
|
||||||
return 0
|
return 0
|
||||||
|
Loading…
Reference in New Issue
Block a user