diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index d0ce5fe8e3..80dbf0aab1 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1896,38 +1896,28 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_cholesky"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; - - let n1 = llvm_ndarray_ty.map_value(n1, None); - let dim0 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - - extern_fns::call_np_linalg_cholesky(ctx, x1, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) + if !x1.get_type().element_type().is_float_type() { + unsupported_type(ctx, FN_NAME, &[x1_ty]); } + + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + .construct_uninitialized(generator, ctx, None); + out.copy_shape_from_ndarray(generator, ctx, x1); + unsafe { out.create_data(generator, ctx) }; + + let x1_c = x1.make_contiguous_ndarray(generator, ctx); + let out_c = out.make_contiguous_ndarray(generator, ctx); + extern_fns::call_np_linalg_cholesky( + ctx, + x1_c.as_base_value().into(), + out_c.as_base_value().into(), + None, + ); + Ok(out.as_base_value().into()) } /// Invokes the `np_linalg_qr` linalg function @@ -1940,44 +1930,45 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); + let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else { - unimplemented!("{FN_NAME} operates on float type NdArrays only"); - }; + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - let n1 = llvm_ndarray_ty.map_value(n1, None); - let dim0 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); - - let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - - extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None); - - let out_ptr = build_output_struct(ctx, &[out_q, out_r]); - - Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) + if !x1.get_type().element_type().is_float_type() { + unsupported_type(ctx, FN_NAME, &[x1_ty]); } + + let x1_shape = x1.shape(); + let d0 = + unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; + let d1 = unsafe { + x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + }; + let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); + + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + let q = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); + unsafe { q.create_data(generator, ctx) }; + + let r = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[dk, d1], None); + unsafe { r.create_data(generator, ctx) }; + + let x1_c = x1.make_contiguous_ndarray(generator, ctx); + let q_c = q.make_contiguous_ndarray(generator, ctx); + let r_c = r.make_contiguous_ndarray(generator, ctx); + + extern_fns::call_np_linalg_qr( + ctx, + x1_c.as_base_value().into(), + q_c.as_base_value().into(), + r_c.as_base_value().into(), + None, + ); + + let q = q.as_base_value().into(); + let r = r.as_base_value().into(); + let out_ptr = build_output_struct(ctx, &[q, r]); + Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap()) } /// Invokes the `np_linalg_svd` linalg function @@ -1990,49 +1981,54 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); + let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - let n1 = llvm_ndarray_ty.map_value(n1, None); - - let dim0 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); - - let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - - extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None); - - let out_ptr = build_output_struct(ctx, &[out_u, out_s, out_vh]); - - Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) + if !x1.get_type().element_type().is_float_type() { + unsupported_type(ctx, FN_NAME, &[x1_ty]); } + + let x1_shape = x1.shape(); + let d0 = + unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; + let d1 = unsafe { + x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + }; + let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); + + let out_ndarray1_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1)); + let out_ndarray2_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + + let u = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d0, d0], None); + unsafe { u.create_data(generator, ctx) }; + + let s = out_ndarray1_ty.construct_dyn_shape(generator, ctx, &[dk], None); + unsafe { s.create_data(generator, ctx) }; + + let vh = out_ndarray2_ty.construct_dyn_shape(generator, ctx, &[d1, d1], None); + unsafe { vh.create_data(generator, ctx) }; + + let x1_c = x1.make_contiguous_ndarray(generator, ctx); + let u_c = u.make_contiguous_ndarray(generator, ctx); + let s_c = s.make_contiguous_ndarray(generator, ctx); + let vh_c = vh.make_contiguous_ndarray(generator, ctx); + + extern_fns::call_np_linalg_svd( + ctx, + x1_c.as_base_value().into(), + u_c.as_base_value().into(), + s_c.as_base_value().into(), + vh_c.as_base_value().into(), + None, + ); + + let u = u.as_base_value().into(); + let s = s.as_base_value().into(); + let vh = vh.as_base_value().into(); + let out_ptr = build_output_struct(ctx, &[u, s, vh]); + + Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap()) } /// Invokes the `np_linalg_inv` linalg function @@ -2043,38 +2039,29 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "np_linalg_inv"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; - - let n1 = llvm_ndarray_ty.map_value(n1, None); - let dim0 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - - extern_fns::call_np_linalg_inv(ctx, x1, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) + if !x1.get_type().element_type().is_float_type() { + unsupported_type(ctx, FN_NAME, &[x1_ty]); } + + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + .construct_uninitialized(generator, ctx, None); + out.copy_shape_from_ndarray(generator, ctx, x1); + unsafe { out.create_data(generator, ctx) }; + + let x1_c = x1.make_contiguous_ndarray(generator, ctx); + let out_c = out.make_contiguous_ndarray(generator, ctx); + extern_fns::call_np_linalg_inv( + ctx, + x1_c.as_base_value().into(), + out_c.as_base_value().into(), + None, + ); + + Ok(out.as_base_value().into()) } /// Invokes the `np_linalg_pinv` linalg function @@ -2087,37 +2074,35 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); + let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - let n1 = llvm_ndarray_ty.map_value(n1, None); - - let dim0 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - - let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - - extern_fns::call_np_linalg_pinv(ctx, x1, out, None); - Ok(out) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) + if !x1.get_type().element_type().is_float_type() { + unsupported_type(ctx, FN_NAME, &[x1_ty]); } + + let x1_shape = x1.shape(); + let d0 = + unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; + let d1 = unsafe { + x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + }; + + let out = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)) + .construct_dyn_shape(generator, ctx, &[d0, d1], None); + unsafe { out.create_data(generator, ctx) }; + + let x1_c = x1.make_contiguous_ndarray(generator, ctx); + let out_c = out.make_contiguous_ndarray(generator, ctx); + extern_fns::call_np_linalg_pinv( + ctx, + x1_c.as_base_value().into(), + out_c.as_base_value().into(), + None, + ); + + Ok(out.as_base_value().into()) } /// Invokes the `sp_linalg_lu` linalg function @@ -2130,44 +2115,45 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( let llvm_usize = generator.get_size_type(ctx.ctx); - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); + let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - let n1 = llvm_ndarray_ty.map_value(n1, None); - - let dim0 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let dim1 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) - .into_int_value() - }; - let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None); - - let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - - extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None); - - let out_ptr = build_output_struct(ctx, &[out_l, out_u]); - Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) + if !x1.get_type().element_type().is_float_type() { + unsupported_type(ctx, FN_NAME, &[x1_ty]); } + + let x1_shape = x1.shape(); + let d0 = + unsafe { x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; + let d1 = unsafe { + x1_shape.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + }; + let dk = llvm_intrinsics::call_int_smin(ctx, d0, d1, None); + + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + + let l = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[d0, dk], None); + unsafe { l.create_data(generator, ctx) }; + + let u = out_ndarray_ty.construct_dyn_shape(generator, ctx, &[dk, d1], None); + unsafe { u.create_data(generator, ctx) }; + + let x1_c = x1.make_contiguous_ndarray(generator, ctx); + let l_c = l.make_contiguous_ndarray(generator, ctx); + let u_c = u.make_contiguous_ndarray(generator, ctx); + extern_fns::call_sp_linalg_lu( + ctx, + x1_c.as_base_value().into(), + l_c.as_base_value().into(), + u_c.as_base_value().into(), + None, + ); + + let l = l.as_base_value().into(); + let u = u.as_base_value().into(); + let out_ptr = build_output_struct(ctx, &[l, u]); + Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap()) } /// Invokes the `np_linalg_matrix_power` linalg function @@ -2242,31 +2228,32 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>( const FN_NAME: &str = "np_linalg_matrix_power"; let llvm_usize = generator.get_size_type(ctx.ctx); - if let BasicValueEnum::PointerValue(_) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); - let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; + let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - // Changing second parameter to a `NDArray` for uniformity in function call - let out = numpy::create_ndarray_const_shape( - generator, - ctx, - elem_ty, - &[llvm_usize.const_int(1, false)], - ) - .unwrap(); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); - extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None); - - let res = - unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - Ok(res) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) + if !x1.get_type().element_type().is_float_type() { + unsupported_type(ctx, FN_NAME, &[x1_ty]); } + + // The output is a float64, but we are using an ndarray (shape == [1]) for uniformity in function call. + let det = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(1)) + .construct_const_shape(generator, ctx, &[1], None); + unsafe { det.create_data(generator, ctx) }; + + let x1_c = x1.make_contiguous_ndarray(generator, ctx); + let out_c = det.make_contiguous_ndarray(generator, ctx); + extern_fns::call_np_linalg_det( + ctx, + x1_c.as_base_value().into(), + out_c.as_base_value().into(), + None, + ); + + // Get the determinant out of `out` + let det = unsafe { det.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; + Ok(det) } /// Invokes the `sp_linalg_schur` linalg function @@ -2277,39 +2264,40 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "sp_linalg_schur"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + assert_eq!(x1.get_type().ndims(), Some(2)); - let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; - - let n1 = llvm_ndarray_ty.map_value(n1, None); - - let dim0 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - - extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None); - - let out_ptr = build_output_struct(ctx, &[out_t, out_z]); - Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) + if !x1.get_type().element_type().is_float_type() { + unsupported_type(ctx, FN_NAME, &[x1_ty]); } + + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + + let t = out_ndarray_ty.construct_uninitialized(generator, ctx, None); + t.copy_shape_from_ndarray(generator, ctx, x1); + unsafe { t.create_data(generator, ctx) }; + + let z = out_ndarray_ty.construct_uninitialized(generator, ctx, None); + z.copy_shape_from_ndarray(generator, ctx, x1); + unsafe { z.create_data(generator, ctx) }; + + let x1_c = x1.make_contiguous_ndarray(generator, ctx); + let t_c = t.make_contiguous_ndarray(generator, ctx); + let z_c = z.make_contiguous_ndarray(generator, ctx); + extern_fns::call_sp_linalg_schur( + ctx, + x1_c.as_base_value().into(), + t_c.as_base_value().into(), + z_c.as_base_value().into(), + None, + ); + + let t = t.as_base_value().into(); + let z = z.as_base_value().into(); + let out_ptr = build_output_struct(ctx, &[t, z]); + Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap()) } /// Invokes the `sp_linalg_hessenberg` linalg function @@ -2320,41 +2308,38 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "sp_linalg_hessenberg"; - let llvm_usize = generator.get_size_type(ctx.ctx); + let BasicValueEnum::PointerValue(x1) = x1 else { unsupported_type(ctx, FN_NAME, &[x1_ty]) }; - if let BasicValueEnum::PointerValue(n1) = x1 { - let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); - let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty); + let x1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(x1, None); + assert_eq!(x1.get_type().ndims(), Some(2)); - let BasicTypeEnum::FloatType(_) = llvm_ndarray_ty.element_type() else { - unsupported_type(ctx, FN_NAME, &[x1_ty]); - }; - - let n1 = llvm_ndarray_ty.map_value(n1, None); - - let dim0 = unsafe { - n1.shape() - .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - .into_int_value() - }; - let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0]) - .map(NDArrayValue::into) - .map(PointerValue::into) - .unwrap(); - - extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None); - - let out_ptr = build_output_struct(ctx, &[out_h, out_q]); - Ok(ctx - .builder - .build_load(out_ptr, "Hessenberg_decomposition_result") - .map(Into::into) - .unwrap()) - } else { - unsupported_type(ctx, FN_NAME, &[x1_ty]) + if !x1.get_type().element_type().is_float_type() { + unsupported_type(ctx, FN_NAME, &[x1_ty]); } + + let out_ndarray_ty = NDArrayType::new(generator, ctx.ctx, ctx.ctx.f64_type().into(), Some(2)); + + let h = out_ndarray_ty.construct_uninitialized(generator, ctx, None); + h.copy_shape_from_ndarray(generator, ctx, x1); + unsafe { h.create_data(generator, ctx) }; + + let q = out_ndarray_ty.construct_uninitialized(generator, ctx, None); + q.copy_shape_from_ndarray(generator, ctx, x1); + unsafe { q.create_data(generator, ctx) }; + + let x1_c = x1.make_contiguous_ndarray(generator, ctx); + let h_c = h.make_contiguous_ndarray(generator, ctx); + let q_c = q.make_contiguous_ndarray(generator, ctx); + extern_fns::call_sp_linalg_hessenberg( + ctx, + x1_c.as_base_value().into(), + h_c.as_base_value().into(), + q_c.as_base_value().into(), + None, + ); + + let h = h.as_base_value().into(); + let q = q.as_base_value().into(); + let out_ptr = build_output_struct(ctx, &[h, q]); + Ok(ctx.builder.build_load(out_ptr, "Hessenberg_decomposition_result").map(Into::into).unwrap()) } diff --git a/nac3core/src/codegen/types/ndarray/contiguous.rs b/nac3core/src/codegen/types/ndarray/contiguous.rs new file mode 100644 index 0000000000..4401cb62af --- /dev/null +++ b/nac3core/src/codegen/types/ndarray/contiguous.rs @@ -0,0 +1,257 @@ +use inkwell::{ + context::Context, + types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType}, + values::{IntValue, PointerValue}, + AddressSpace, +}; +use itertools::Itertools; + +use nac3core_derive::StructFields; + +use crate::{ + codegen::{ + types::{ + structure::{ + check_struct_type_matches_fields, FieldIndexCounter, StructField, StructFields, + }, + ProxyType, + }, + values::{ndarray::ContiguousNDArrayValue, ArraySliceValue, ProxyValue}, + CodeGenContext, CodeGenerator, + }, + toplevel::numpy::unpack_ndarray_var_tys, + typecheck::typedef::Type, +}; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct ContiguousNDArrayType<'ctx> { + ty: PointerType<'ctx>, + item: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, +} + +#[derive(PartialEq, Eq, Clone, Copy, StructFields)] +pub struct ContiguousNDArrayFields<'ctx> { + #[value_type(usize)] + pub ndims: StructField<'ctx, IntValue<'ctx>>, + #[value_type(usize.ptr_type(AddressSpace::default()))] + pub shape: StructField<'ctx, PointerValue<'ctx>>, + #[value_type(i8_type().ptr_type(AddressSpace::default()))] + pub data: StructField<'ctx, PointerValue<'ctx>>, +} + +impl<'ctx> ContiguousNDArrayFields<'ctx> { + #[must_use] + pub fn new_typed(item: BasicTypeEnum<'ctx>, llvm_usize: IntType<'ctx>) -> Self { + let mut counter = FieldIndexCounter::default(); + + ContiguousNDArrayFields { + ndims: StructField::create(&mut counter, "ndims", llvm_usize), + shape: StructField::create( + &mut counter, + "shape", + llvm_usize.ptr_type(AddressSpace::default()), + ), + data: StructField::create(&mut counter, "data", item.ptr_type(AddressSpace::default())), + } + } +} + +impl<'ctx> ContiguousNDArrayType<'ctx> { + /// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not. + pub fn is_representable( + llvm_ty: PointerType<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + let ctx = llvm_ty.get_context(); + + let llvm_ty = llvm_ty.get_element_type(); + let AnyTypeEnum::StructType(llvm_ty) = llvm_ty else { + return Err(format!( + "Expected struct type for `ContiguousNDArray` type, got {llvm_ty}" + )); + }; + + let fields = ContiguousNDArrayFields::new(ctx, llvm_usize); + + check_struct_type_matches_fields( + fields, + llvm_ty, + "ContiguousNDArray", + &[(fields.data.name(), &|ty| { + if ty.is_pointer_type() { + Ok(()) + } else { + Err(format!("Expected T* for `ContiguousNDArray.data`, got {ty}")) + } + })], + ) + } + + /// Returns an instance of [`StructFields`] containing all field accessors for this type. + #[must_use] + fn fields( + item: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> ContiguousNDArrayFields<'ctx> { + ContiguousNDArrayFields::new_typed(item, llvm_usize) + } + + /// See [`NDArrayType::fields`]. + // TODO: Move this into e.g. StructProxyType + #[must_use] + pub fn get_fields(&self) -> ContiguousNDArrayFields<'ctx> { + Self::fields(self.item, self.llvm_usize) + } + + /// Creates an LLVM type corresponding to the expected structure of an `NDArray`. + #[must_use] + fn llvm_type( + ctx: &'ctx Context, + item: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> PointerType<'ctx> { + let field_tys = + Self::fields(item, llvm_usize).into_iter().map(|field| field.1).collect_vec(); + + ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default()) + } + + /// Creates an instance of [`ContiguousNDArrayType`]. + #[must_use] + pub fn new( + generator: &G, + ctx: &'ctx Context, + item: BasicTypeEnum<'ctx>, + ) -> Self { + let llvm_usize = generator.get_size_type(ctx); + let llvm_cndarray = Self::llvm_type(ctx, item, llvm_usize); + + Self { ty: llvm_cndarray, item, llvm_usize } + } + + /// Creates an [`ContiguousNDArrayType`] from a [unifier type][Type]. + #[must_use] + pub fn from_unifier_type( + generator: &G, + ctx: &mut CodeGenContext<'ctx, '_>, + ty: Type, + ) -> Self { + let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); + + let llvm_dtype = ctx.get_llvm_type(generator, dtype); + let llvm_usize = generator.get_size_type(ctx.ctx); + + Self { ty: Self::llvm_type(ctx.ctx, llvm_dtype, llvm_usize), item: llvm_dtype, llvm_usize } + } + + /// Creates an [`ContiguousNDArrayType`] from a [`PointerType`] representing an `NDArray`. + #[must_use] + pub fn from_type( + ptr_ty: PointerType<'ctx>, + item: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Self { + debug_assert!(Self::is_representable(ptr_ty, llvm_usize).is_ok()); + + Self { ty: ptr_ty, item, llvm_usize } + } + + /// Allocates an instance of [`ContiguousNDArrayValue`] as if by calling `alloca` on the base type. + #[must_use] + pub fn alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + self.raw_alloca(generator, ctx, name), + self.item, + self.llvm_usize, + name, + ) + } + + /// Converts an existing value into a [`ContiguousNDArrayValue`]. + #[must_use] + pub fn map_value( + &self, + value: <>::Value as ProxyValue<'ctx>>::Base, + name: Option<&'ctx str>, + ) -> >::Value { + >::Value::from_pointer_value( + value, + self.item, + self.llvm_usize, + name, + ) + } +} + +impl<'ctx> ProxyType<'ctx> for ContiguousNDArrayType<'ctx> { + type Base = PointerType<'ctx>; + type Value = ContiguousNDArrayValue<'ctx>; + + fn is_type( + generator: &G, + ctx: &'ctx Context, + llvm_ty: impl BasicType<'ctx>, + ) -> Result<(), String> { + if let BasicTypeEnum::PointerType(ty) = llvm_ty.as_basic_type_enum() { + >::is_representable(generator, ctx, ty) + } else { + Err(format!("Expected pointer type, got {llvm_ty:?}")) + } + } + + fn is_representable( + generator: &G, + ctx: &'ctx Context, + llvm_ty: Self::Base, + ) -> Result<(), String> { + Self::is_representable(llvm_ty, generator.get_size_type(ctx)) + } + + fn raw_alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + name: Option<&'ctx str>, + ) -> >::Base { + generator + .gen_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + name, + ) + .unwrap() + } + + fn array_alloca( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + size: IntValue<'ctx>, + name: Option<&'ctx str>, + ) -> ArraySliceValue<'ctx> { + generator + .gen_array_var_alloc( + ctx, + self.as_base_type().get_element_type().into_struct_type().into(), + size, + name, + ) + .unwrap() + } + + fn as_base_type(&self) -> Self::Base { + self.ty + } +} + +impl<'ctx> From> for PointerType<'ctx> { + fn from(value: ContiguousNDArrayType<'ctx>) -> Self { + value.as_base_type() + } +} diff --git a/nac3core/src/codegen/types/ndarray/mod.rs b/nac3core/src/codegen/types/ndarray/mod.rs index d485eeb47a..ae390e5d7e 100644 --- a/nac3core/src/codegen/types/ndarray/mod.rs +++ b/nac3core/src/codegen/types/ndarray/mod.rs @@ -20,6 +20,9 @@ use crate::{ toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys}, typecheck::typedef::Type, }; +pub use contiguous::*; + +mod contiguous; /// Proxy type for a `ndarray` type in LLVM. #[derive(Debug, PartialEq, Eq, Clone, Copy)] diff --git a/nac3core/src/codegen/types/structure.rs b/nac3core/src/codegen/types/structure.rs index 4622e9b29b..4d6dcaf73f 100644 --- a/nac3core/src/codegen/types/structure.rs +++ b/nac3core/src/codegen/types/structure.rs @@ -103,6 +103,12 @@ where StructField { index, name, ty: ty.into(), _value_ty: PhantomData } } + /// Returns the name of this field. + #[must_use] + pub fn name(&self) -> &'static str { + self.name + } + /// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32 /// {idx...}, i32 {self.index}`. pub fn ptr_by_array_gep( diff --git a/nac3core/src/codegen/values/ndarray/contiguous.rs b/nac3core/src/codegen/values/ndarray/contiguous.rs new file mode 100644 index 0000000000..87e2f1d82f --- /dev/null +++ b/nac3core/src/codegen/values/ndarray/contiguous.rs @@ -0,0 +1,202 @@ +use inkwell::{ + types::{BasicType, BasicTypeEnum, IntType}, + values::{IntValue, PointerValue}, + AddressSpace, +}; + +use super::{ArrayLikeValue, NDArrayValue, ProxyValue}; +use crate::codegen::{ + stmt::gen_if_callback, + types::{ + ndarray::{ContiguousNDArrayType, NDArrayType}, + structure::StructField, + }, + CodeGenContext, CodeGenerator, +}; + +#[derive(Copy, Clone)] +pub struct ContiguousNDArrayValue<'ctx> { + value: PointerValue<'ctx>, + item: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, +} + +impl<'ctx> ContiguousNDArrayValue<'ctx> { + /// Checks whether `value` is an instance of `ContiguousNDArray`, returning [Err] if `value` is + /// not an instance. + pub fn is_representable( + value: PointerValue<'ctx>, + llvm_usize: IntType<'ctx>, + ) -> Result<(), String> { + >::Type::is_representable(value.get_type(), llvm_usize) + } + + /// Creates an [`ContiguousNDArrayValue`] from a [`PointerValue`]. + #[must_use] + pub fn from_pointer_value( + ptr: PointerValue<'ctx>, + dtype: BasicTypeEnum<'ctx>, + llvm_usize: IntType<'ctx>, + name: Option<&'ctx str>, + ) -> Self { + debug_assert!(Self::is_representable(ptr, llvm_usize).is_ok()); + + Self { value: ptr, item: dtype, llvm_usize, name } + } + + fn ndims_field(&self) -> StructField<'ctx, IntValue<'ctx>> { + self.get_type().get_fields().ndims + } + + pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { + self.ndims_field().set(ctx, self.as_base_value(), value, self.name); + } + + fn shape_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().shape + } + + pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { + self.shape_field().set(ctx, self.as_base_value(), value, self.name); + } + + pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.shape_field().get(ctx, self.value, self.name) + } + + fn data_field(&self) -> StructField<'ctx, PointerValue<'ctx>> { + self.get_type().get_fields().data + } + + pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { + self.data_field().set(ctx, self.as_base_value(), value, self.name); + } + + pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + self.data_field().get(ctx, self.value, self.name) + } +} + +impl<'ctx> ProxyValue<'ctx> for ContiguousNDArrayValue<'ctx> { + type Base = PointerValue<'ctx>; + type Type = ContiguousNDArrayType<'ctx>; + + fn get_type(&self) -> Self::Type { + >::Type::from_type( + self.as_base_value().get_type(), + self.item, + self.llvm_usize, + ) + } + + fn as_base_value(&self) -> Self::Base { + self.value + } +} + +impl<'ctx> From> for PointerValue<'ctx> { + fn from(value: ContiguousNDArrayValue<'ctx>) -> Self { + value.as_base_value() + } +} + +impl<'ctx> NDArrayValue<'ctx> { + /// Create a [`ContiguousNDArrayValue`] from the contents of this ndarray. + /// + /// This function may or may not be expensive depending on if this ndarray has contiguous data. + /// + /// If this ndarray is not C-contiguous, this function will allocate memory on the stack for the + /// `data` field of the returned [`ContiguousNDArrayValue`] and copy contents of this ndarray to + /// there. + /// + /// If this ndarray is C-contiguous, contents of this ndarray will not be copied. The created + /// [`ContiguousNDArrayValue`] will share memory with this ndarray. + pub fn make_contiguous_ndarray( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> ContiguousNDArrayValue<'ctx> { + let result = ContiguousNDArrayType::new(generator, ctx.ctx, self.dtype) + .alloca(generator, ctx, self.name); + + // Set ndims and shape. + let ndims = self + .ndims + .map_or_else(|| self.load_ndims(ctx), |ndims| self.llvm_usize.const_int(ndims, false)); + result.store_ndims(ctx, ndims); + + let shape = self.shape(); + result.store_shape(ctx, shape.base_ptr(ctx, generator)); + + gen_if_callback( + generator, + ctx, + |generator, ctx| Ok(self.is_c_contiguous(generator, ctx)), + |_, ctx| { + // This ndarray is contiguous. + let data = self.data_field(ctx).get(ctx, self.as_base_value(), self.name); + let data = ctx + .builder + .build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "") + .unwrap(); + result.store_data(ctx, data); + + Ok(()) + }, + |generator, ctx| { + // This ndarray is not contiguous. Do a full-copy on `data`. `make_copy` produces an + // ndarray with contiguous `data`. + let copied_ndarray = self.make_copy(generator, ctx); + let data = copied_ndarray.data().base_ptr(ctx, generator); + let data = ctx + .builder + .build_pointer_cast(data, result.item.ptr_type(AddressSpace::default()), "") + .unwrap(); + result.store_data(ctx, data); + + Ok(()) + }, + ) + .unwrap(); + + result + } + + /// Create an [`NDArrayValue`] from a [`ContiguousNDArrayValue`]. + /// + /// The operation is cheap. The newly created [`NDArrayValue`] will share the same memory as the + /// [`ContiguousNDArrayValue`]. + /// + /// `ndims` has to be provided as [`NDArrayValue`] requires a statically known `ndims` value, + /// despite the fact that the information should be contained within the + /// [`ContiguousNDArrayValue`]. + pub fn from_contiguous_ndarray( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + carray: ContiguousNDArrayValue<'ctx>, + ndims: u64, + ) -> Self { + // TODO: Debug assert `ndims == carray.ndims` to catch bugs. + + // Allocate the resulting ndarray. + let ndarray = NDArrayType::new(generator, ctx.ctx, carray.item, Some(ndims)) + .construct_uninitialized(generator, ctx, carray.name); + + // Copy shape and update strides + let shape = carray.load_shape(ctx); + ndarray.copy_shape_from_array(generator, ctx, shape); + ndarray.set_strides_contiguous(generator, ctx); + + // Share data + let data = carray.load_data(ctx); + ndarray.store_data( + ctx, + ctx.builder + .build_pointer_cast(data, ctx.ctx.i8_type().ptr_type(AddressSpace::default()), "") + .unwrap(), + ); + + ndarray + } +} diff --git a/nac3core/src/codegen/values/ndarray/mod.rs b/nac3core/src/codegen/values/ndarray/mod.rs index 61504c2d58..084c9b0315 100644 --- a/nac3core/src/codegen/values/ndarray/mod.rs +++ b/nac3core/src/codegen/values/ndarray/mod.rs @@ -16,6 +16,9 @@ use crate::codegen::{ types::{ndarray::NDArrayType, structure::StructField}, CodeGenContext, CodeGenerator, }; +pub use contiguous::*; + +mod contiguous; /// Proxy type for accessing an `NDArray` value in LLVM. #[derive(Copy, Clone)] @@ -362,6 +365,25 @@ impl<'ctx> NDArrayValue<'ctx> { irrt::ndarray::call_nac3_ndarray_set_strides_by_shape(generator, ctx, *self); } + #[must_use] + pub fn make_copy( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Self { + let clone = if self.ndims.is_some() { + self.get_type().construct_uninitialized(generator, ctx, None) + } else { + self.get_type().construct_dyn_ndims(generator, ctx, self.load_ndims(ctx), None) + }; + + let shape = self.shape(); + clone.copy_shape_from_array(generator, ctx, shape.base_ptr(ctx, generator)); + unsafe { clone.create_data(generator, ctx) }; + clone.copy_data_from(generator, ctx, *self); + clone + } + /// Copy data from another ndarray. /// /// This ndarray and `src` is that their `np.size()` should be the same. Their shapes diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 161d59f7ce..e82a6b7d6d 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1759,14 +1759,14 @@ def run() -> int32: test_ndarray_reshape() test_ndarray_dot() - # test_ndarray_cholesky() - # test_ndarray_qr() - # test_ndarray_svd() - # test_ndarray_linalg_inv() - # test_ndarray_pinv() + test_ndarray_cholesky() + test_ndarray_qr() + test_ndarray_svd() + test_ndarray_linalg_inv() + test_ndarray_pinv() # test_ndarray_matrix_power() - # test_ndarray_det() - # test_ndarray_lu() - # test_ndarray_schur() - # test_ndarray_hessenberg() + test_ndarray_det() + test_ndarray_lu() + test_ndarray_schur() + test_ndarray_hessenberg() return 0