forked from M-Labs/nac3
[core] codegen: Cleanup builtin_fns.rs
- Unpack tuples directly in function argument - Replace Vec parameters with slices - Replace unwrap-transform with map-unwrap
This commit is contained in:
parent
8e2b50df21
commit
ad67a99c8f
@ -43,11 +43,10 @@ fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) -
|
||||
pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(arg_ty, arg): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<IntValue<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let range_ty = ctx.primitives.range;
|
||||
let (arg_ty, arg) = n;
|
||||
|
||||
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
|
||||
let arg = RangeValue::from_pointer_value(arg.into_pointer_value(), Some("range"));
|
||||
@ -105,12 +104,11 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(n_ty, n): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
Ok(match n {
|
||||
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
||||
@ -168,13 +166,11 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(n_ty, n): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let llvm_i64 = ctx.ctx.i64_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
Ok(match n {
|
||||
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => {
|
||||
debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,]
|
||||
@ -231,13 +227,11 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(n_ty, n): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
Ok(match n {
|
||||
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
||||
@ -310,13 +304,11 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(n_ty, n): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let llvm_i64 = ctx.ctx.i64_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
Ok(match n {
|
||||
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32) => {
|
||||
debug_assert!([ctx.primitives.bool, ctx.primitives.int32, ctx.primitives.uint32,]
|
||||
@ -378,13 +370,11 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(n_ty, n): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
Ok(match n {
|
||||
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8 | 32 | 64) => {
|
||||
debug_assert!([
|
||||
@ -445,14 +435,12 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(n_ty, n): (Type, BasicValueEnum<'ctx>),
|
||||
ret_elem_ty: Type,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "round";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty).into_int_type();
|
||||
|
||||
Ok(match n {
|
||||
@ -492,14 +480,12 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(n_ty, n): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_round";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
Ok(match n {
|
||||
BasicValueEnum::FloatValue(n) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.float));
|
||||
@ -533,14 +519,12 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(n_ty, n): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "bool";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
|
||||
Ok(match n {
|
||||
BasicValueEnum::IntValue(n) if matches!(n.get_type().get_bit_width(), 1 | 8) => {
|
||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
||||
@ -603,14 +587,12 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(n_ty, n): (Type, BasicValueEnum<'ctx>),
|
||||
ret_elem_ty: Type,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "floor";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty);
|
||||
|
||||
Ok(match n {
|
||||
@ -654,14 +636,12 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(n_ty, n): (Type, BasicValueEnum<'ctx>),
|
||||
ret_elem_ty: Type,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "ceil";
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (n_ty, n) = n;
|
||||
let llvm_ret_elem_ty = ctx.get_llvm_abi_type(generator, ret_elem_ty);
|
||||
|
||||
Ok(match n {
|
||||
@ -704,14 +684,11 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// Invokes the `min` builtin function.
|
||||
pub fn call_min<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
m: (Type, BasicValueEnum<'ctx>),
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(m_ty, m): (Type, BasicValueEnum<'ctx>),
|
||||
(n_ty, n): (Type, BasicValueEnum<'ctx>),
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
const FN_NAME: &str = "min";
|
||||
|
||||
let (m_ty, m) = m;
|
||||
let (n_ty, n) = n;
|
||||
|
||||
let common_ty = if ctx.unifier.unioned(m_ty, n_ty) {
|
||||
m_ty
|
||||
} else {
|
||||
@ -754,14 +731,11 @@ pub fn call_min<'ctx>(
|
||||
pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_minimum";
|
||||
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None };
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
@ -836,14 +810,11 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// Invokes the `max` builtin function.
|
||||
pub fn call_max<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
m: (Type, BasicValueEnum<'ctx>),
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
(m_ty, m): (Type, BasicValueEnum<'ctx>),
|
||||
(n_ty, n): (Type, BasicValueEnum<'ctx>),
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
const FN_NAME: &str = "max";
|
||||
|
||||
let (m_ty, m) = m;
|
||||
let (n_ty, n) = n;
|
||||
|
||||
let common_ty = if ctx.unifier.unioned(m_ty, n_ty) {
|
||||
m_ty
|
||||
} else {
|
||||
@ -887,7 +858,7 @@ pub fn call_max<'ctx>(
|
||||
pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
a: (Type, BasicValueEnum<'ctx>),
|
||||
(a_ty, a): (Type, BasicValueEnum<'ctx>),
|
||||
fn_name: &str,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
debug_assert!(["np_argmin", "np_argmax", "np_max", "np_min"].iter().any(|f| *f == fn_name));
|
||||
@ -895,7 +866,6 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let llvm_int64 = ctx.ctx.i64_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (a_ty, a) = a;
|
||||
Ok(match a {
|
||||
BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => {
|
||||
debug_assert!([
|
||||
@ -1016,14 +986,11 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_maximum";
|
||||
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
let common_ty = if ctx.unifier.unioned(x1_ty, x2_ty) { Some(x1_ty) } else { None };
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
@ -1163,6 +1130,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
|
||||
n: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "abs";
|
||||
|
||||
helper_call_numpy_unary_elementwise(
|
||||
generator,
|
||||
ctx,
|
||||
@ -1473,14 +1441,11 @@ create_helper_call_numpy_unary_elementwise_float_to_float!(
|
||||
pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_arctan2";
|
||||
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
@ -1540,14 +1505,11 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_copysign";
|
||||
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
@ -1607,14 +1569,11 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_fmax";
|
||||
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
@ -1674,14 +1633,11 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_fmin";
|
||||
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
@ -1741,14 +1697,11 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_ldexp";
|
||||
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::IntValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
@ -1797,14 +1750,11 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_hypot";
|
||||
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
@ -1864,14 +1814,11 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_nextafter";
|
||||
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
Ok(match (x1, x2) {
|
||||
(BasicValueEnum::FloatValue(x1), BasicValueEnum::FloatValue(x2)) => {
|
||||
debug_assert!(ctx.unifier.unioned(x1_ty, ctx.primitives.float));
|
||||
@ -1930,14 +1877,13 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|
||||
/// Allocates a struct with the fields specified by `out_matrices` and returns a pointer to it
|
||||
fn build_output_struct<'ctx>(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
out_matrices: Vec<BasicValueEnum<'ctx>>,
|
||||
out_matrices: &[BasicValueEnum<'ctx>],
|
||||
) -> PointerValue<'ctx> {
|
||||
let field_ty =
|
||||
out_matrices.iter().map(BasicValueEnum::get_type).collect::<Vec<BasicTypeEnum>>();
|
||||
let field_ty = out_matrices.iter().map(BasicValueEnum::get_type).collect_vec();
|
||||
let out_ty = ctx.ctx.struct_type(&field_ty, false);
|
||||
let out_ptr = ctx.builder.build_alloca(out_ty, "").unwrap();
|
||||
|
||||
for (i, v) in out_matrices.into_iter().enumerate() {
|
||||
for (i, v) in out_matrices.iter().enumerate() {
|
||||
unsafe {
|
||||
let ptr = ctx
|
||||
.builder
|
||||
@ -1950,7 +1896,7 @@ fn build_output_struct<'ctx>(
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder.build_store(ptr, v).unwrap();
|
||||
ctx.builder.build_store(ptr, *v).unwrap();
|
||||
}
|
||||
}
|
||||
out_ptr
|
||||
@ -1960,10 +1906,10 @@ fn build_output_struct<'ctx>(
|
||||
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_cholesky";
|
||||
let (x1_ty, x1) = x1;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
@ -1987,9 +1933,9 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
|
||||
extern_fns::call_np_linalg_cholesky(ctx, x1, out, None);
|
||||
Ok(out)
|
||||
@ -2002,10 +1948,10 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_qr";
|
||||
let (x1_ty, x1) = x1;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
@ -2030,17 +1976,17 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||
|
||||
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
let out_r = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
|
||||
extern_fns::call_np_linalg_qr(ctx, x1, out_q, out_r, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_q, out_r]);
|
||||
let out_ptr = build_output_struct(ctx, &[out_q, out_r]);
|
||||
|
||||
Ok(ctx.builder.build_load(out_ptr, "QR_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
@ -2052,10 +1998,10 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_svd";
|
||||
let (x1_ty, x1) = x1;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
@ -2081,21 +2027,21 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||
|
||||
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
let out_s = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
let out_vh = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
|
||||
extern_fns::call_np_linalg_svd(ctx, x1, out_u, out_s, out_vh, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_u, out_s, out_vh]);
|
||||
let out_ptr = build_output_struct(ctx, &[out_u, out_s, out_vh]);
|
||||
|
||||
Ok(ctx.builder.build_load(out_ptr, "SVD_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
@ -2107,10 +2053,10 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_inv";
|
||||
let (x1_ty, x1) = x1;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
@ -2134,9 +2080,9 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
|
||||
extern_fns::call_np_linalg_inv(ctx, x1, out, None);
|
||||
Ok(out)
|
||||
@ -2149,10 +2095,10 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_pinv";
|
||||
let (x1_ty, x1) = x1;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
@ -2177,9 +2123,9 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim1, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
|
||||
extern_fns::call_np_linalg_pinv(ctx, x1, out, None);
|
||||
Ok(out)
|
||||
@ -2192,10 +2138,10 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_lu";
|
||||
let (x1_ty, x1) = x1;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
@ -2221,17 +2167,17 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||
let k = llvm_intrinsics::call_int_smin(ctx, dim0, dim1, None);
|
||||
|
||||
let out_l = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, k])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
let out_u = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[k, dim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
|
||||
extern_fns::call_sp_linalg_lu(ctx, x1, out_l, out_u, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_l, out_u]);
|
||||
let out_ptr = build_output_struct(ctx, &[out_l, out_u]);
|
||||
Ok(ctx.builder.build_load(out_ptr, "LU_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
@ -2242,12 +2188,11 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
x2: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
(x2_ty, x2): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||
let (x1_ty, x1) = x1;
|
||||
let (x2_ty, x2) = x2;
|
||||
|
||||
let x2 = call_float(generator, ctx, (x2_ty, x2)).unwrap();
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
@ -2290,11 +2235,12 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||
};
|
||||
|
||||
let out = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[outdim0, outdim1])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
|
||||
extern_fns::call_np_linalg_matrix_power(ctx, x1, n2_array, out, None);
|
||||
|
||||
Ok(out)
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
||||
@ -2305,10 +2251,9 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_linalg_matrix_power";
|
||||
let (x1_ty, x1) = x1;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
if let BasicValueEnum::PointerValue(_) = x1 {
|
||||
@ -2327,7 +2272,9 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
||||
&[llvm_usize.const_int(1, false)],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
extern_fns::call_np_linalg_det(ctx, x1, out.as_base_value().as_basic_value_enum(), None);
|
||||
|
||||
let res =
|
||||
unsafe { out.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||
Ok(res)
|
||||
@ -2340,10 +2287,10 @@ pub fn call_np_linalg_det<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_schur";
|
||||
let (x1_ty, x1) = x1;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
@ -2362,17 +2309,17 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||
.into_int_value()
|
||||
};
|
||||
let out_t = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
let out_z = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
|
||||
extern_fns::call_sp_linalg_schur(ctx, x1, out_t, out_z, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_t, out_z]);
|
||||
let out_ptr = build_output_struct(ctx, &[out_t, out_z]);
|
||||
Ok(ctx.builder.build_load(out_ptr, "Schur_Factorization_result").map(Into::into).unwrap())
|
||||
} else {
|
||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
@ -2383,10 +2330,10 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
|
||||
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
(x1_ty, x1): (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "sp_linalg_hessenberg";
|
||||
let (x1_ty, x1) = x1;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
@ -2405,16 +2352,17 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
||||
.into_int_value()
|
||||
};
|
||||
let out_h = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
let out_q = numpy::create_ndarray_const_shape(generator, ctx, elem_ty, &[dim0, dim0])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
.map(NDArrayValue::into)
|
||||
.map(PointerValue::into)
|
||||
.unwrap();
|
||||
|
||||
extern_fns::call_sp_linalg_hessenberg(ctx, x1, out_h, out_q, None);
|
||||
|
||||
let out_ptr = build_output_struct(ctx, vec![out_h, out_q]);
|
||||
let out_ptr = build_output_struct(ctx, &[out_h, out_q]);
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_load(out_ptr, "Hessenberg_decomposition_result")
|
||||
|
Loading…
Reference in New Issue
Block a user