diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 20d3500..3e6a37c 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -43,11 +43,10 @@ fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) - pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (arg_ty, arg): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); let range_ty = ctx.primitives.range; - let (arg_ty, arg) = n; Ok(if ctx.unifier.unioned(arg_ty, range_ty) { let arg = RangeValue::from_pointer_value(arg.into_pointer_value(), Some("range")); @@ -105,12 +104,11 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); @@ -168,13 +166,11 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i64 = ctx.ctx.i64_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; - Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => { debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,] @@ -231,13 +227,11 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; - Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); @@ -310,13 +304,11 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_i64 = ctx.ctx.i64_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; - Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => { debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,] @@ -378,13 +370,11 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { let llvm_f64 = ctx.ctx.f64_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; - Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => { debug_assert!([ @@ -445,14 +435,12 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ret_elem_ty: Type, ) -> Result, String> { const FN_NAME: &str = "round"; let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty).into_int_type(); Ok(match n { @@ -492,14 +480,12 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_round"; let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; - Ok(match n { BasicValueEnum::FloatValue(n) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float)); @@ -533,14 +519,12 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "bool"; let llvm_usize = generator.get_size_type(ctx.ctx); - let (n_ty, n) = n; - Ok(match n { BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => { debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool)); @@ -603,14 +587,12 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ret_elem_ty: Type, ) -> Result, String> { const FN_NAME: &str = "floor"; let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); Ok(match n { @@ -654,14 +636,12 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - n: (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ret_elem_ty: Type, ) -> Result, String> { const FN_NAME: &str = "ceil"; let llvm_usize = generator.get_size_type(ctx.ctx); - - let (n_ty, n) = n; let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty); Ok(match n { @@ -704,14 +684,11 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( /// Invokes the `min` builtin function. pub fn call_min<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, - m: (Type, BasicValueEnum<'ctx>), - n: (Type, BasicValueEnum<'ctx>), + (m_ty, m): (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> BasicValueEnum<'ctx> { const FN_NAME: &str = "min"; - let (m_ty, m) = m; - let (n_ty, n) = n; - let common_ty = if ctx.unifier.unioned(m_ty, n_ty) { m_ty } else { @@ -754,14 +731,11 @@ pub fn call_min<'ctx>( pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_minimum"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None }; Ok(match (x1, x2) { @@ -836,14 +810,11 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( /// Invokes the `max` builtin function. pub fn call_max<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, - m: (Type, BasicValueEnum<'ctx>), - n: (Type, BasicValueEnum<'ctx>), + (m_ty, m): (Type, BasicValueEnum<'ctx>), + (n_ty, n): (Type, BasicValueEnum<'ctx>), ) -> BasicValueEnum<'ctx> { const FN_NAME: &str = "max"; - let (m_ty, m) = m; - let (n_ty, n) = n; - let common_ty = if ctx.unifier.unioned(m_ty, n_ty) { m_ty } else { @@ -887,7 +858,7 @@ pub fn call_max<'ctx>( pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - a: (Type, BasicValueEnum<'ctx>), + (a_ty, a): (Type, BasicValueEnum<'ctx>), fn_name: &str, ) -> Result, String> { debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name)); @@ -895,7 +866,6 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( let llvm_int64 = ctx.ctx.i64_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let (a_ty, a) = a; Ok(match a { BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { debug_assert!([ @@ -1016,14 +986,11 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_maximum"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None }; Ok(match (x1, x2) { @@ -1163,6 +1130,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "abs"; + helper_call_numpy_unary_elementwise( generator, ctx, @@ -1473,14 +1441,11 @@ create_helper_call_numpy_unary_elementwise_float_to_float!( pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_arctan2"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); @@ -1540,14 +1505,11 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_copysign"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); @@ -1607,14 +1569,11 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_fmax"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); @@ -1674,14 +1633,11 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_fmin"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); @@ -1741,14 +1697,11 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_ldexp"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => { debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); @@ -1797,14 +1750,11 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_hypot"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); @@ -1864,14 +1814,11 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_nextafter"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; - Ok(match (x1, x2) { (BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => { debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float)); @@ -1930,14 +1877,13 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( /// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it fn build_output_struct<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, - out_matrices: Vec>, + out_matrices: &[BasicValueEnum<'ctx>], ) -> PointerValue<'ctx> { - let field_ty = - out_matrices.iter().map(BasicValueEnum::get_type).collect::>(); + let field_ty = out_matrices.iter().map(BasicValueEnum::get_type).collect_vec(); let out_ty = ctx.ctx.struct_type(&field_ty, false); let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap(); - for (i, v) in out_matrices.into_iter().enumerate() { + for (i, v) in out_matrices.iter().enumerate() { unsafe { let ptr = ctx .builder @@ -1950,7 +1896,7 @@ fn build_output_struct<'ctx>( "", ) .unwrap(); - ctx.builder.build_store(ptr, v).unwrap(); + ctx.builder.build_store(ptr, *v).unwrap(); } } out_ptr @@ -1960,10 +1906,10 @@ fn build_output_struct<'ctx>( pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_linalg_cholesky"; - let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); if let BasicValueEnum::PointerValue(n1) = x1 { @@ -1987,9 +1933,9 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( }; let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); extern_fns::call_np_linalg_cholesky(ctx, x1, out, None); Ok(out) @@ -2002,10 +1948,10 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_linalg_qr"; - let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); if let BasicValueEnum::PointerValue(n1) = x1 { @@ -2030,17 +1976,17 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None); - let out_ptr = build_output_struct(ctx, vec![out_q, out_r]); + let out_ptr = build_output_struct(ctx, &[out_q, out_r]); Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) } else { @@ -2052,10 +1998,10 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_linalg_svd"; - let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); if let BasicValueEnum::PointerValue(n1) = x1 { @@ -2081,21 +2027,21 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None); - let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]); + let out_ptr = build_output_struct(ctx, &[out_u, out_s, out_vh]); Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) } else { @@ -2107,10 +2053,10 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_linalg_inv"; - let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); if let BasicValueEnum::PointerValue(n1) = x1 { @@ -2134,9 +2080,9 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( }; let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); extern_fns::call_np_linalg_inv(ctx, x1, out, None); Ok(out) @@ -2149,10 +2095,10 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_linalg_pinv"; - let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); if let BasicValueEnum::PointerValue(n1) = x1 { @@ -2177,9 +2123,9 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( }; let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); extern_fns::call_np_linalg_pinv(ctx, x1, out, None); Ok(out) @@ -2192,10 +2138,10 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "sp_linalg_lu"; - let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); if let BasicValueEnum::PointerValue(n1) = x1 { @@ -2221,17 +2167,17 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None); - let out_ptr = build_output_struct(ctx, vec![out_l, out_u]); + let out_ptr = build_output_struct(ctx, &[out_l, out_u]); Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap()) } else { unsupported_type(ctx, FN_NAME, &[x1_ty]) @@ -2242,12 +2188,11 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), - x2: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), + (x2_ty, x2): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_linalg_matrix_power"; - let (x1_ty, x1) = x1; - let (x2_ty, x2) = x2; + let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap(); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -2290,11 +2235,12 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( }; let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None); + Ok(out) } else { unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]) @@ -2305,10 +2251,9 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "np_linalg_matrix_power"; - let (x1_ty, x1) = x1; let llvm_usize = generator.get_size_type(ctx.ctx); if let BasicValueEnum::PointerValue(_) = x1 { @@ -2327,7 +2272,9 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( &[llvm_usize.const_int(1, false)], ) .unwrap(); + extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None); + let res = unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; Ok(res) @@ -2340,10 +2287,10 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "sp_linalg_schur"; - let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); if let BasicValueEnum::PointerValue(n1) = x1 { @@ -2362,17 +2309,17 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( .into_int_value() }; let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None); - let out_ptr = build_output_struct(ctx, vec![out_t, out_z]); + let out_ptr = build_output_struct(ctx, &[out_t, out_z]); Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap()) } else { unsupported_type(ctx, FN_NAME, &[x1_ty]) @@ -2383,10 +2330,10 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - x1: (Type, BasicValueEnum<'ctx>), + (x1_ty, x1): (Type, BasicValueEnum<'ctx>), ) -> Result, String> { const FN_NAME: &str = "sp_linalg_hessenberg"; - let (x1_ty, x1) = x1; + let llvm_usize = generator.get_size_type(ctx.ctx); if let BasicValueEnum::PointerValue(n1) = x1 { @@ -2405,16 +2352,17 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( .into_int_value() }; let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .unwrap() - .as_base_value() - .as_basic_value_enum(); + .map(NDArrayValue::into) + .map(PointerValue::into) + .unwrap(); + extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None); - let out_ptr = build_output_struct(ctx, vec![out_h, out_q]); + let out_ptr = build_output_struct(ctx, &[out_h, out_q]); Ok(ctx .builder .build_load(out_ptr, "Hessenberg_decomposition_result")