1
0
forked from M-Labs/nac3

[core] codegen: Cleanup builtin_fns.rs

- Unpack tuples directly in function argument
- Replace Vec parameters with slices
- Replace unwrap-transform with map-unwrap
This commit is contained in:
David Mak 2024-12-10 12:44:29 +08:00
parent 8e2b50df21
commit ad67a99c8f

View File

@ -43,11 +43,10 @@ fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) -
pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
(arg_ty, arg): (Type, BasicValueEnum<'ctx>),
) -> Result<IntValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let range_ty = ctx.primitives.range;
let (arg_ty, arg) = n;
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
let arg = RangeValue::from_pointer_value(arg.into_pointer_value(), Some("range"));
@ -105,12 +104,11 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
(n_ty, n): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
@ -168,13 +166,11 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
(n_ty, n): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
let llvm_i64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => {
debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,]
@ -231,13 +227,11 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
(n_ty, n): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
@ -310,13 +304,11 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
(n_ty, n): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
let llvm_i64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => {
debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,]
@ -378,13 +370,11 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
(n_ty, n): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => {
debug_assert!([
@ -445,14 +435,12 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
(n_ty, n): (Type, BasicValueEnum<'ctx>),
ret_elem_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "round";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty).into_int_type();
Ok(match n {
@ -492,14 +480,12 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
(n_ty, n): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_round";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
Ok(match n {
BasicValueEnum::FloatValue(n) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
@ -533,14 +519,12 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
(n_ty, n): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "bool";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
Ok(match n {
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
@ -603,14 +587,12 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
(n_ty, n): (Type, BasicValueEnum<'ctx>),
ret_elem_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "floor";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty);
Ok(match n {
@ -654,14 +636,12 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>),
(n_ty, n): (Type, BasicValueEnum<'ctx>),
ret_elem_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ceil";
let llvm_usize = generator.get_size_type(ctx.ctx);
let (n_ty, n) = n;
let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty);
Ok(match n {
@ -704,14 +684,11 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
/// Invokes the `min` builtin function.
pub fn call_min<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
m: (Type, BasicValueEnum<'ctx>),
n: (Type, BasicValueEnum<'ctx>),
(m_ty, m): (Type, BasicValueEnum<'ctx>),
(n_ty, n): (Type, BasicValueEnum<'ctx>),
) -> BasicValueEnum<'ctx> {
const FN_NAME: &str = "min";
let (m_ty, m) = m;
let (n_ty, n) = n;
let common_ty = if ctx.unifier.unioned(m_ty, n_ty) {
m_ty
} else {
@ -754,14 +731,11 @@ pub fn call_min<'ctx>(
pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_minimum";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None };
Ok(match (x1, x2) {
@ -836,14 +810,11 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
/// Invokes the `max` builtin function.
pub fn call_max<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
m: (Type, BasicValueEnum<'ctx>),
n: (Type, BasicValueEnum<'ctx>),
(m_ty, m): (Type, BasicValueEnum<'ctx>),
(n_ty, n): (Type, BasicValueEnum<'ctx>),
) -> BasicValueEnum<'ctx> {
const FN_NAME: &str = "max";
let (m_ty, m) = m;
let (n_ty, n) = n;
let common_ty = if ctx.unifier.unioned(m_ty, n_ty) {
m_ty
} else {
@ -887,7 +858,7 @@ pub fn call_max<'ctx>(
pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
a: (Type, BasicValueEnum<'ctx>),
(a_ty, a): (Type, BasicValueEnum<'ctx>),
fn_name: &str,
) -> Result<BasicValueEnum<'ctx>, String> {
debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name));
@ -895,7 +866,6 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
let llvm_int64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let (a_ty, a) = a;
Ok(match a {
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
debug_assert!([
@ -1016,14 +986,11 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_maximum";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None };
Ok(match (x1, x2) {
@ -1163,6 +1130,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
n: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "abs";
helper_call_numpy_unary_elementwise(
generator,
ctx,
@ -1473,14 +1441,11 @@ create_helper_call_numpy_unary_elementwise_float_to_float!(
pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_arctan2";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
@ -1540,14 +1505,11 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_copysign";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
@ -1607,14 +1569,11 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_fmax";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
@ -1674,14 +1633,11 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_fmin";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
@ -1741,14 +1697,11 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_ldexp";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
@ -1797,14 +1750,11 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_hypot";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
@ -1864,14 +1814,11 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_nextafter";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
Ok(match (x1, x2) {
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
@ -1930,14 +1877,13 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it
fn build_output_struct<'ctx>(
ctx: &mut CodeGenContext<'ctx, '_>,
out_matrices: Vec<BasicValueEnum<'ctx>>,
out_matrices: &[BasicValueEnum<'ctx>],
) -> PointerValue<'ctx> {
let field_ty =
out_matrices.iter().map(BasicValueEnum::get_type).collect::<Vec<BasicTypeEnum>>();
let field_ty = out_matrices.iter().map(BasicValueEnum::get_type).collect_vec();
let out_ty = ctx.ctx.struct_type(&field_ty, false);
let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap();
for (i, v) in out_matrices.into_iter().enumerate() {
for (i, v) in out_matrices.iter().enumerate() {
unsafe {
let ptr = ctx
.builder
@ -1950,7 +1896,7 @@ fn build_output_struct<'ctx>(
"",
)
.unwrap();
ctx.builder.build_store(ptr, v).unwrap();
ctx.builder.build_store(ptr, *v).unwrap();
}
}
out_ptr
@ -1960,10 +1906,10 @@ fn build_output_struct<'ctx>(
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_cholesky";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
@ -1987,9 +1933,9 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
Ok(out)
@ -2002,10 +1948,10 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_qr";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
@ -2030,17 +1976,17 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
let out_ptr = build_output_struct(ctx, vec![out_q, out_r]);
let out_ptr = build_output_struct(ctx, &[out_q, out_r]);
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
} else {
@ -2052,10 +1998,10 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_svd";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
@ -2081,21 +2027,21 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None);
let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]);
let out_ptr = build_output_struct(ctx, &[out_u, out_s, out_vh]);
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
} else {
@ -2107,10 +2053,10 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_inv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
@ -2134,9 +2080,9 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
extern_fns::call_np_linalg_inv(ctx, x1, out, None);
Ok(out)
@ -2149,10 +2095,10 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_pinv";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
@ -2177,9 +2123,9 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
extern_fns::call_np_linalg_pinv(ctx, x1, out, None);
Ok(out)
@ -2192,10 +2138,10 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_lu";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
@ -2221,17 +2167,17 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None);
let out_ptr = build_output_struct(ctx, vec![out_l, out_u]);
let out_ptr = build_output_struct(ctx, &[out_l, out_u]);
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
@ -2242,12 +2188,11 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
x2: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let (x2_ty, x2) = x2;
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
let llvm_usize = generator.get_size_type(ctx.ctx);
@ -2290,11 +2235,12 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
};
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
Ok(out)
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2305,10 +2251,9 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "np_linalg_matrix_power";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(_) = x1 {
@ -2327,7 +2272,9 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
&[llvm_usize.const_int(1, false)],
)
.unwrap();
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
let res =
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
Ok(res)
@ -2340,10 +2287,10 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_schur";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
@ -2362,17 +2309,17 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
.into_int_value()
};
let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None);
let out_ptr = build_output_struct(ctx, vec![out_t, out_z]);
let out_ptr = build_output_struct(ctx, &[out_t, out_z]);
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
} else {
unsupported_type(ctx, FN_NAME, &[x1_ty])
@ -2383,10 +2330,10 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "sp_linalg_hessenberg";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
@ -2405,16 +2352,17 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
.into_int_value()
};
let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
.unwrap()
.as_base_value()
.as_basic_value_enum();
.map(NDArrayValue::into)
.map(PointerValue::into)
.unwrap();
extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None);
let out_ptr = build_output_struct(ctx, vec![out_h, out_q]);
let out_ptr = build_output_struct(ctx, &[out_h, out_q]);
Ok(ctx
.builder
.build_load(out_ptr, "Hessenberg_decomposition_result")