forked from M-Labs/nac3
core/ndstrides: implement nalgebra functions
This commit is contained in:
parent
1562a938a1
commit
54f7e1edfd
|
@ -1,17 +1,14 @@
|
|||
use inkwell::types::BasicTypeEnum;
|
||||
use inkwell::values::{BasicValue, BasicValueEnum, IntValue, PointerValue};
|
||||
use inkwell::values::{BasicValue, BasicValueEnum, IntValue};
|
||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::classes::{
|
||||
NDArrayValue, ProxyValue, RangeValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||
};
|
||||
use crate::codegen::classes::RangeValue;
|
||||
use crate::codegen::expr::destructure_range;
|
||||
use crate::codegen::irrt::calculate_len_for_slice_range;
|
||||
use crate::codegen::object::ndarray::{NDArrayOut, ScalarOrNDArray};
|
||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, CodeGenContext, CodeGenerator};
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
use crate::typecheck::typedef::Type;
|
||||
|
||||
use super::model::*;
|
||||
|
@ -1607,500 +1604,332 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
|||
Ok(result.to_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it
|
||||
fn build_output_struct<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
out_matrices: Vec<BasicValueEnum<'ctx>>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let field_ty =
|
||||
out_matrices.iter().map(BasicValueEnum::get_type).collect::<Vec<BasicTypeEnum>>();
|
||||
let out_ty = ctx.ctx.struct_type(&field_ty, false);
|
||||
let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap();
|
||||
|
||||
for (i, v) in out_matrices.into_iter().enumerate() {
|
||||
unsafe {
|
||||
let ptr = ctx
|
||||
.builder
|
||||
.build_in_bounds_gep(
|
||||
out_ptr,
|
||||
&[
|
||||
ctx.ctx.i32_type().const_zero(),
|
||||
ctx.ctx.i32_type().const_int(i as u64, false),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder.build_store(ptr, v).unwrap();
|
||||
}
|
||||
}
|
||||
out_ptr
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_cholesky` linalg function
|
||||
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_cholesky";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1 };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let out = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2);
|
||||
out.copy_shape_from_ndarray(generator, ctx, x1);
|
||||
out.create_data(generator, ctx);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64));
|
||||
let out_c = out.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
extern_fns::call_np_linalg_cholesky(
|
||||
ctx,
|
||||
x1_c.value.as_basic_value_enum(),
|
||||
out_c.value.as_basic_value_enum(),
|
||||
None,
|
||||
);
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
Ok(out.instance.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_qr` linalg function
|
||||
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_qr";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1 };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let x1_shape = x1.instance.get(generator, ctx, |f| f.shape);
|
||||
let d0 = x1_shape.get_index_const(generator, ctx, 0);
|
||||
let d1 = x1_shape.get_index_const(generator, ctx, 1);
|
||||
let dk =
|
||||
Int(SizeT).believe_value(llvm_intrinsics::call_int_smin(ctx, d0.value, d1.value, None));
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unimplemented!("{FN_NAME} operates on float type NdArrays only");
|
||||
};
|
||||
let q = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[d0, dk]);
|
||||
q.create_data(generator, ctx);
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||
let r = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[dk, d1]);
|
||||
r.create_data(generator, ctx);
|
||||
|
||||
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64));
|
||||
let q_c = q.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
let r_c = r.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
extern_fns::call_np_linalg_qr(
|
||||
ctx,
|
||||
x1_c.value.as_basic_value_enum(),
|
||||
q_c.value.as_basic_value_enum(),
|
||||
r_c.value.as_basic_value_enum(),
|
||||
None,
|
||||
);
|
||||
|
||||
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_q, out_r]);
|
||||
|
||||
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
let q = q.to_any(ctx);
|
||||
let r = r.to_any(ctx);
|
||||
let tuple = TupleObject::from_objects(generator, ctx, [q, r]);
|
||||
Ok(tuple.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_svd` linalg function
|
||||
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_svd";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1 };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let x1_shape = x1.instance.get(generator, ctx, |f| f.shape);
|
||||
let d0 = x1_shape.get_index_const(generator, ctx, 0);
|
||||
let d1 = x1_shape.get_index_const(generator, ctx, 1);
|
||||
let dk =
|
||||
Int(SizeT).believe_value(llvm_intrinsics::call_int_smin(ctx, d0.value, d1.value, None));
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
let u = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[d0, d0]);
|
||||
u.create_data(generator, ctx);
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let s = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[dk]);
|
||||
s.create_data(generator, ctx);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||
let vh = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[d1, d1]);
|
||||
vh.create_data(generator, ctx);
|
||||
|
||||
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64));
|
||||
let u_c = u.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
let s_c = s.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
let vh_c = vh.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
extern_fns::call_np_linalg_svd(
|
||||
ctx,
|
||||
x1_c.value.as_basic_value_enum(),
|
||||
u_c.value.as_basic_value_enum(),
|
||||
s_c.value.as_basic_value_enum(),
|
||||
vh_c.value.as_basic_value_enum(),
|
||||
None,
|
||||
);
|
||||
|
||||
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]);
|
||||
|
||||
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
let u = u.to_any(ctx);
|
||||
let s = s.to_any(ctx);
|
||||
let vh = vh.to_any(ctx);
|
||||
let tuple = TupleObject::from_objects(generator, ctx, [u, s, vh]);
|
||||
Ok(tuple.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_inv` linalg function
|
||||
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_inv";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1 };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let out = NDArrayObject::alloca(generator, ctx, x1.dtype, 2);
|
||||
out.copy_shape_from_ndarray(generator, ctx, x1);
|
||||
out.create_data(generator, ctx);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64));
|
||||
let out_c = out.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
extern_fns::call_np_linalg_inv(
|
||||
ctx,
|
||||
x1_c.value.as_basic_value_enum(),
|
||||
out_c.value.as_basic_value_enum(),
|
||||
None,
|
||||
);
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_inv(ctx, x1, out, None);
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
Ok(out.instance.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_pinv` linalg function
|
||||
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_pinv";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1 };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let x1_shape = x1.instance.get(generator, ctx, |f| f.shape);
|
||||
let d0 = x1_shape.get_index_const(generator, ctx, 0);
|
||||
let d1 = x1_shape.get_index_const(generator, ctx, 1);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
let out = NDArrayObject::alloca_dynamic_shape(generator, ctx, x1.dtype, &[d1, d0]);
|
||||
out.create_data(generator, ctx);
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64));
|
||||
let out_c = out.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
extern_fns::call_np_linalg_pinv(
|
||||
ctx,
|
||||
x1_c.value.as_basic_value_enum(),
|
||||
out_c.value.as_basic_value_enum(),
|
||||
None,
|
||||
);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_pinv(ctx, x1, out, None);
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
Ok(out.instance.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Invokes the `sp_linalg_lu` linalg function
|
||||
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_lu";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1 };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let x1_shape = x1.instance.get(generator, ctx, |f| f.shape);
|
||||
let d0 = x1_shape.get_index_const(generator, ctx, 0);
|
||||
let d1 = x1_shape.get_index_const(generator, ctx, 1);
|
||||
let dk =
|
||||
Int(SizeT).believe_value(llvm_intrinsics::call_int_smin(ctx, d0.value, d1.value, None));
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
let l = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[d0, dk]);
|
||||
l.create_data(generator, ctx);
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let u = NDArrayObject::alloca_dynamic_shape(generator, ctx, ctx.primitives.float, &[dk, d1]);
|
||||
u.create_data(generator, ctx);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let dim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||
let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64));
|
||||
let l_c = l.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
let u_c = u.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
extern_fns::call_sp_linalg_lu(
|
||||
ctx,
|
||||
x1_c.value.as_basic_value_enum(),
|
||||
l_c.value.as_basic_value_enum(),
|
||||
u_c.value.as_basic_value_enum(),
|
||||
None,
|
||||
);
|
||||
|
||||
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_l, out_u]);
|
||||
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
let l = l.to_any(ctx);
|
||||
let u = u.to_any(ctx);
|
||||
let tuple = TupleObject::from_objects(generator, ctx, [l, u]);
|
||||
Ok(tuple.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_matrix_power` linalg function
|
||||
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1 };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
// x2 is a float, but we are promoting this to a 1D ndarray (.shape == [1]) for uniformity in function call.
|
||||
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
|
||||
let x2 = AnyObject { ty: ctx.primitives.float, value: x2 };
|
||||
let x2 = NDArrayObject::make_unsized(generator, ctx, x2); // x2.shape == []
|
||||
let x2 = x2.atleast_nd(generator, ctx, 1); // x2.shape == [1]
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
if let (BasicValueEnum::PointerValue(n1), BasicValueEnum::FloatValue(n2)) = (x1, x2) {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let out = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2);
|
||||
out.copy_shape_from_ndarray(generator, ctx, x1);
|
||||
out.create_data(generator, ctx);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
||||
};
|
||||
let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64));
|
||||
let x2_c = x2.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
let out_c = out.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
extern_fns::call_np_linalg_matrix_power(
|
||||
ctx,
|
||||
x1_c.value.as_basic_value_enum(),
|
||||
x2_c.value.as_basic_value_enum(),
|
||||
out_c.value.as_basic_value_enum(),
|
||||
None,
|
||||
);
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||
let n2_array = numpy::create_ndarray_const_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&[llvm_usize.const_int(1, false)],
|
||||
)
|
||||
.unwrap();
|
||||
unsafe {
|
||||
n2_array.data().set_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_zero(),
|
||||
n2.as_basic_value_enum(),
|
||||
);
|
||||
};
|
||||
let n2_array = n2_array.as_base_value().as_basic_value_enum();
|
||||
|
||||
let outdim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let outdim1 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None)
|
||||
.into_int_value()
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
}
|
||||
Ok(out.instance.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Invokes the `np_linalg_det` linalg function
|
||||
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||
let (x1_ty, x1) = x1;
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1 };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
if let BasicValueEnum::PointerValue(_) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
// The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call.
|
||||
let det = NDArrayObject::alloca_constant_shape(generator, ctx, ctx.primitives.float, &[1]);
|
||||
det.create_data(generator, ctx);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64));
|
||||
let out_c = det.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
extern_fns::call_np_linalg_det(
|
||||
ctx,
|
||||
x1_c.value.as_basic_value_enum(),
|
||||
out_c.value.as_basic_value_enum(),
|
||||
None,
|
||||
);
|
||||
|
||||
// Changing second parameter to a `NDArray` for uniformity in function call
|
||||
let out = numpy::create_ndarray_const_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&[llvm_usize.const_int(1, false)],
|
||||
)
|
||||
.unwrap();
|
||||
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
|
||||
let res =
|
||||
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||
Ok(res)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
// Get the determinant out of `out`
|
||||
let zero = Int(SizeT).const_0(generator, ctx.ctx);
|
||||
let det = det.get_nth_scalar(generator, ctx, zero);
|
||||
Ok(det.value)
|
||||
}
|
||||
|
||||
/// Invokes the `sp_linalg_schur` linalg function
|
||||
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_schur";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1 };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
assert_eq!(x1.ndims, 2);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let t = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2);
|
||||
t.copy_shape_from_ndarray(generator, ctx, x1);
|
||||
t.create_data(generator, ctx);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
let z = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2);
|
||||
z.copy_shape_from_ndarray(generator, ctx, x1);
|
||||
z.create_data(generator, ctx);
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64));
|
||||
let t_c = t.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
let z_c = z.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
extern_fns::call_sp_linalg_schur(
|
||||
ctx,
|
||||
x1_c.value.as_basic_value_enum(),
|
||||
t_c.value.as_basic_value_enum(),
|
||||
z_c.value.as_basic_value_enum(),
|
||||
None,
|
||||
);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_t, out_z]);
|
||||
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
let t = t.to_any(ctx);
|
||||
let z = z.to_any(ctx);
|
||||
let tuple = TupleObject::from_objects(generator, ctx, [t, z]);
|
||||
Ok(tuple.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
/// Invokes the `sp_linalg_hessenberg` linalg function
|
||||
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_hessenberg";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1 };
|
||||
let x1 = NDArrayObject::from_object(generator, ctx, x1);
|
||||
assert_eq!(x1.ndims, 2);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
|
||||
let h = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2);
|
||||
h.copy_shape_from_ndarray(generator, ctx, x1);
|
||||
h.create_data(generator, ctx);
|
||||
|
||||
let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty]);
|
||||
};
|
||||
let q = NDArrayObject::alloca(generator, ctx, ctx.primitives.float, 2);
|
||||
q.copy_shape_from_ndarray(generator, ctx, x1);
|
||||
q.create_data(generator, ctx);
|
||||
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let x1_c = x1.make_contiguous_ndarray(generator, ctx, Float(Float64));
|
||||
let h_c = h.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
let q_c = q.make_contiguous_ndarray(generator, ctx, Float(Float64)); // Shares `data`.
|
||||
extern_fns::call_sp_linalg_hessenberg(
|
||||
ctx,
|
||||
x1_c.value.as_basic_value_enum(),
|
||||
h_c.value.as_basic_value_enum(),
|
||||
q_c.value.as_basic_value_enum(),
|
||||
None,
|
||||
);
|
||||
|
||||
let dim0 = unsafe {
|
||||
n1.dim_sizes()
|
||||
.get_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
.into_int_value()
|
||||
};
|
||||
let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_h, out_q]);
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_load(out_ptr, "Hessenberg_decomposition_result")
|
||||
.map(Into::into)
|
||||
.unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
let h = h.to_any(ctx);
|
||||
let q = q.to_any(ctx);
|
||||
let tuple = TupleObject::from_objects(generator, ctx, [h, q]);
|
||||
Ok(tuple.value.as_basic_value_enum())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue