Compare commits

...

9 Commits

13 changed files with 1066 additions and 22 deletions

106
Cargo.lock generated
View File

@ -73,6 +73,15 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "approx"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
dependencies = [
"num-traits",
]
[[package]] [[package]]
name = "ascii-canvas" name = "ascii-canvas"
version = "3.0.0" version = "3.0.0"
@ -247,6 +256,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "cslice"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f8cb7306107e4b10e64994de6d3274bd08996a7c1322a27b86482392f96be0a"
[[package]] [[package]]
name = "dirs-next" name = "dirs-next"
version = "2.0.0" version = "2.0.0"
@ -521,6 +536,12 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "libm"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]] [[package]]
name = "libredox" name = "libredox"
version = "0.1.3" version = "0.1.3"
@ -531,6 +552,14 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "linalg"
version = "0.1.0"
dependencies = [
"cslice",
"nalgebra",
]
[[package]] [[package]]
name = "linked-hash-map" name = "linked-hash-map"
version = "0.5.6" version = "0.5.6"
@ -659,17 +688,70 @@ version = "0.1.0"
dependencies = [ dependencies = [
"clap", "clap",
"inkwell", "inkwell",
"linalg",
"nac3core", "nac3core",
"nac3parser", "nac3parser",
"parking_lot", "parking_lot",
] ]
[[package]]
name = "nalgebra"
version = "0.32.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4"
dependencies = [
"approx",
"num-complex",
"num-rational",
"num-traits",
"simba",
"typenum",
]
[[package]] [[package]]
name = "new_debug_unreachable" name = "new_debug_unreachable"
version = "1.0.6" version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086"
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
"libm",
]
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.19.0" version = "1.19.0"
@ -699,6 +781,12 @@ dependencies = [
"windows-targets", "windows-targets",
] ]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]] [[package]]
name = "petgraph" name = "petgraph"
version = "0.6.5" version = "0.6.5"
@ -1070,6 +1158,18 @@ dependencies = [
"yaml-rust", "yaml-rust",
] ]
[[package]]
name = "simba"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
dependencies = [
"approx",
"num-complex",
"num-traits",
"paste",
]
[[package]] [[package]]
name = "similar" name = "similar"
version = "2.6.0" version = "2.6.0"
@ -1230,6 +1330,12 @@ dependencies = [
"crunchy", "crunchy",
] ]
[[package]]
name = "typenum"
version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]] [[package]]
name = "unic-char-property" name = "unic-char-property"
version = "0.9.0" version = "0.9.0"

View File

@ -161,7 +161,9 @@
clippy clippy
pre-commit pre-commit
rustfmt rustfmt
rust-analyzer
]; ];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
}; };
devShells.x86_64-linux.msys2 = pkgs.mkShell { devShells.x86_64-linux.msys2 = pkgs.mkShell {
name = "nac3-dev-shell-msys2"; name = "nac3-dev-shell-msys2";

View File

@ -1865,7 +1865,7 @@ fn build_output_struct<'ctx>(
out_ptr out_ptr
} }
/// Invokes the `np_dot` using `nalgebra` crate /// Invokes the `np_dot` linalg function
pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -1884,7 +1884,7 @@ pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty) let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else { else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
}; };
Ok(extern_fns::call_np_dot(ctx, x1, x2, None).into()) Ok(extern_fns::call_np_dot(ctx, x1, x2, None).into())
@ -1893,7 +1893,7 @@ pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `np_linalg_matmul` using `nalgebra` crate /// Invokes the `np_linalg_matmul` linalg function
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -1913,7 +1913,7 @@ pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty) let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
else { else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -1942,7 +1942,7 @@ pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `np_linalg_cholesky` using `nalgebra` crate /// Invokes the `np_linalg_cholesky` linalg function
pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -1957,7 +1957,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -1984,7 +1984,7 @@ pub fn call_np_linalg_cholesky<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `np_linalg_qr` using `nalgebra` crate /// Invokes the `np_linalg_qr` linalg function
pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2034,7 +2034,7 @@ pub fn call_np_linalg_qr<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `np_linalg_svd` using `nalgebra` crate /// Invokes the `np_linalg_svd` linalg function
pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2049,7 +2049,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -2089,7 +2089,7 @@ pub fn call_np_linalg_svd<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `np_linalg_inv` using `nalgebra` crate /// Invokes the `np_linalg_inv` linalg function
pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2104,7 +2104,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -2131,7 +2131,7 @@ pub fn call_np_linalg_inv<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `np_linalg_pinv` using `nalgebra` crate /// Invokes the `np_linalg_pinv` linalg function
pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>( pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2146,7 +2146,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -2174,7 +2174,7 @@ pub fn call_np_linalg_pinv<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `sp_linalg_lu` using `nalgebra` crate /// Invokes the `sp_linalg_lu` linalg function
pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2189,7 +2189,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -2224,7 +2224,7 @@ pub fn call_sp_linalg_lu<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `sp_linalg_schur` using `nalgebra` crate /// Invokes the `sp_linalg_schur` linalg function
pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2239,7 +2239,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
@ -2267,7 +2267,7 @@ pub fn call_sp_linalg_schur<'ctx, G: CodeGenerator + ?Sized>(
} }
} }
/// Invokes the `sp_linalg_hessenberg` using `nalgebra` crate /// Invokes the `sp_linalg_hessenberg` linalg function
pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2282,7 +2282,7 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty); let n1_elem_ty = ctx.get_llvm_type(generator, elem_ty);
let BasicTypeEnum::FloatType(_) = n1_elem_ty else { let BasicTypeEnum::FloatType(_) = n1_elem_ty else {
unimplemented!("{FN_NAME} operates on float type NdArrays only"); unsupported_type(ctx, FN_NAME, &[x1_ty]);
}; };
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);

View File

@ -2026,3 +2026,405 @@ pub fn gen_ndarray_fill<'ctx>(
Ok(()) Ok(())
} }
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_transpose";
let (x1_ty, x1) = x1;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
let out = create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&n1,
|_, ctx, n| Ok(n.load_ndims(ctx)),
|generator, ctx, n, idx| {
let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap();
let new_idx = ctx
.builder
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
.unwrap();
unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) }
},
)
.unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
// Calculate transposed idx
// 2, 3 => idx = row * num_col + col = 3 + 2 = 5
// 2, 3, 4 => idx = row * (num_col*num_z) + col * (num_z) + z => 12 + 8 + 3 | 18, [4, 2] 15, ,1,2
// num_z (col + num_col * row) + z
// 4D => 2, 3, 4, 5 = idx = num_w (num_z (col + num_col*row) + z) + w = 119
// [1, 1, 2]?
// 4 (1 + 3*1) + 2 = 6
// z = 2, col = 1, row = 0
// 0,1,
// 2, 3 => idx = row * num_col + col | 2, 3, 4 => idx = row * (num_col * num_z) + col * (num_z) + num_z
// ND => idx = 1 * (dim0 + dim1 + ... dimn) + dim[-1] * (dim0 + dim1 + ... + dimn-1) + ... + dim[1] * dim0
// 6 + 12 + 6 = 24 num_z * (row*num_col + col + 1) 4*6=24
// 2, 3, 4, 5 at idx 1 should go to
// 5, 4, 3, 2
// 18 => [2, 4] dim = 4
// 0 * 4 + 2 = 2
// 4 => [1, 1] dim = 3
// 2 * 3 + 1
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
ctx.builder.build_store(rem_idx, idx).unwrap();
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n1.load_ndims(ctx), false),
|generator, ctx, _, ndim| {
let ndim_rev =
ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap();
let ndim_rev = ctx
.builder
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
.unwrap();
let dim = unsafe {
n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None)
};
let rem_idx_val =
ctx.builder.build_load(rem_idx, "").unwrap().into_int_value();
let new_idx_val =
ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
let add_component =
ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap();
let rem_idx_val =
ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap();
let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap();
let new_idx_val =
ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap();
ctx.builder.build_store(rem_idx, rem_idx_val).unwrap();
ctx.builder.build_store(new_idx, new_idx_val).unwrap();
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
x1: (Type, BasicValueEnum<'ctx>),
shape: (Type, BasicValueEnum<'ctx>),
) -> Result<BasicValueEnum<'ctx>, String> {
const FN_NAME: &str = "ndarray_reshape";
let (x1_ty, x1) = x1;
let (_, shape) = shape;
let llvm_usize = generator.get_size_type(ctx.ctx);
if let BasicValueEnum::PointerValue(n1) = x1 {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
// Check for -1 in the shapec
let ndim_ty = match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
{
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
shape_list
.data()
.get(ctx, generator, &llvm_usize.const_zero(), None)
.into_int_value()
.get_type()
}
BasicValueEnum::StructValue(shape_tuple) => ctx
.builder
.build_extract_value(shape_tuple, 0, "")
.unwrap()
.into_int_value()
.get_type(),
BasicValueEnum::IntValue(shape_int) => shape_int.get_type(),
_ => unreachable!(),
};
let n_sz = ctx
.builder
.build_cast(inkwell::values::InstructionOpcode::Trunc, n_sz, ndim_ty, "")
.unwrap()
.into_int_value();
let acc = generator.gen_var_alloc(ctx, ndim_ty.into(), None)?;
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
ctx.builder.build_store(acc, ndim_ty.const_int(1, false)).unwrap();
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
let out = match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
{
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(shape_list.load_size(ctx, None), false),
|generator, ctx, _, idx| {
let ele =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
ele,
ndim_ty.const_zero(),
"",
)
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_neg_value =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_neg_value = ctx
.builder
.build_int_add(
num_neg_value,
llvm_usize.const_int(1, false),
"",
)
.unwrap();
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
Ok(None)
},
|_, ctx| {
let acc_value =
ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_value =
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
ctx.builder.build_store(acc, acc_value).unwrap();
Ok(None)
},
)?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape_list,
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|generator, ctx, shape_list, idx| {
let dim =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
Ok(gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
dim,
ndim_ty.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value())
},
)
}
BasicValueEnum::StructValue(shape_tuple) => {
let ndims = shape_tuple.get_type().count_fields();
let acc = ctx.builder.build_alloca(ndim_ty, "").unwrap();
ctx.builder.build_store(acc, ndim_ty.const_int(1, false)).unwrap();
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::SLT, dim, ndim_ty.const_zero(), "")
.unwrap())
},
|_, ctx| -> Result<Option<IntValue>, String> {
let num_negs =
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
let num_negs = ctx
.builder
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
.unwrap();
ctx.builder.build_store(num_neg, num_negs).unwrap();
Ok(None)
},
|_, ctx| {
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(None)
},
)?;
}
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
let mut shape = Vec::with_capacity(ndims as usize);
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, "")
.unwrap()
.into_int_value();
let dim = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(IntPredicate::SLT, dim, ndim_ty.const_zero(), "")
.unwrap())
},
|_, _| Ok(Some(rem)),
|_, _| Ok(Some(dim)),
)?
.unwrap()
.into_int_value();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
}
BasicValueEnum::IntValue(shape_int) => {
let shape_int = gen_if_else_expr_callback(
generator,
ctx,
|_, ctx| {
Ok(ctx
.builder
.build_int_compare(
IntPredicate::SLT,
shape_int,
ndim_ty.const_zero(),
"",
)
.unwrap())
},
|_, _| Ok(Some(n_sz)),
|_, _| Ok(Some(shape_int)),
)?
.unwrap()
.into_int_value();
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
_ => unreachable!(),
}
.unwrap();
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
.unwrap(),
"0:ValueError",
"can only specify one unknown dimension",
[None, None, None],
ctx.current_loc,
);
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
let out_sz = ctx
.builder
.build_cast(inkwell::values::InstructionOpcode::Trunc, out_sz, ndim_ty, "")
.unwrap()
.into_int_value();
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
"0:ValueError",
"cannot reshape array of size {} into provided shape of size {}",
[Some(n_sz), Some(out_sz), None],
ctx.current_loc,
);
gen_for_callback_incrementing(
generator,
ctx,
None,
llvm_usize.const_zero(),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
Ok(())
},
llvm_usize.const_int(1, false),
)?;
Ok(out.as_base_value().into())
} else {
unreachable!(
"{FN_NAME}() not supported for '{}'",
format!("'{}'", ctx.unifier.stringify(x1_ty))
)
}
}

View File

@ -557,6 +557,10 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpHypot | PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
self.build_np_sp_ndarray_1ary_function(prim)
}
PrimDef::FunNpDot PrimDef::FunNpDot
| PrimDef::FunNpLinalgMatmul | PrimDef::FunNpLinalgMatmul
| PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgCholesky
@ -1885,6 +1889,66 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
/// Build 1-ary numpy/scipy functions that take in an ndarray and return a value of the same type as the input.
fn build_np_sp_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
let elem_type = self.unifier.get_fresh_var(Some("R".into()), None);
let ndarray_type = make_ndarray_ty(self.unifier, self.primitives, Some(elem_type.ty), None);
let ndarray_ty =
self.unifier.get_fresh_var_with_range(&[ndarray_type], Some("T".into()), None);
let var_map = into_var_map([elem_type, ndarray_ty]);
match prim {
PrimDef::FunNpTranspose => create_fn_by_codegen(
self.unifier,
&var_map,
prim.name(),
ndarray_ty.ty,
&[(ndarray_ty.ty, "x")],
Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
}),
),
PrimDef::FunNpReshape => {
// Return type can have differend ndims
let ret_ty =
self.unifier.get_fresh_var_with_range(&[ndarray_type], Some("U".into()), None);
let var_map = var_map
.clone()
.into_iter()
.chain(once((ret_ty.id, ret_ty.ty)))
.chain(once((
self.ndarray_factory_fn_shape_arg_tvar.id,
self.ndarray_factory_fn_shape_arg_tvar.ty,
)))
.collect::<IndexMap<_, _>>();
create_fn_by_codegen(
self.unifier,
&var_map,
prim.name(),
ret_ty.ty,
&[(ndarray_ty.ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val =
args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
)
}
_ => unreachable!(),
}
}
/// Build `np_linalg` and `sp_linalg` functions /// Build `np_linalg` and `sp_linalg` functions
/// ///
/// The input to these functions must be floating point `NDArray` /// The input to these functions must be floating point `NDArray`

View File

@ -99,7 +99,10 @@ pub enum PrimDef {
FunNpLdExp, FunNpLdExp,
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
FunNpTranspose,
FunNpReshape,
// Linalg functions
FunNpDot, FunNpDot,
FunNpLinalgMatmul, FunNpLinalgMatmul,
FunNpLinalgCholesky, FunNpLinalgCholesky,
@ -281,6 +284,10 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunNpTranspose => fun("np_transpose", None),
PrimDef::FunNpReshape => fun("np_reshape", None),
// Linalg functions
PrimDef::FunNpDot => fun("np_dot", None), PrimDef::FunNpDot => fun("np_dot", None),
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None), PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None), PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),

View File

@ -8,6 +8,7 @@ edition = "2021"
parking_lot = "0.12" parking_lot = "0.12"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
nac3core = { path = "../nac3core" } nac3core = { path = "../nac3core" }
linalg = { path = "./demo/linalg" }
[dependencies.clap] [dependencies.clap]
version = "4.5" version = "4.5"

View File

@ -218,6 +218,8 @@ def patch(module):
module.np_ldexp = np.ldexp module.np_ldexp = np.ldexp
module.np_hypot = np.hypot module.np_hypot = np.hypot
module.np_nextafter = np.nextafter module.np_nextafter = np.nextafter
module.np_transpose = np.transpose
module.np_reshape = np.reshape
# SciPy Math functions # SciPy Math functions
module.sp_spec_erf = special.erf module.sp_spec_erf = special.erf

View File

@ -0,0 +1,11 @@
[package]
name = "linalg"
version = "0.1.0"
edition = "2021"
[lib]
crate-type = ["staticlib"]
[dependencies]
nalgebra = {version = "0.32.6", default-features = false, features = ["libm", "alloc"]}
cslice = "0.3.0"

View File

@ -0,0 +1,413 @@
// Uses `nalgebra` crate to invoke `np_linalg` and `sp_linalg` functions
// When converting between `nalgebra::Matrix` and `NDArray` following considerations are necessary
//
// * Both `nalgebra::Matrix` and `NDArray` require their content to be stored in row-major order
// * `NDArray` data pointer can be directly read and converted to `nalgebra::Matrix` (row and column number must be known)
// * `nalgebra::Matrix::as_slice` returns the content of matrix in column-major order and initial data needs to be transposed before storing it in `NDArray` data pointer
use core::slice;
use nalgebra::DMatrix;
fn report_error(
error_name: &str,
fn_name: &str,
file_name: &str,
line_num: u32,
col_num: u32,
err_msg: &str,
) -> ! {
panic!(
"Exception {} from {} in {}:{}:{}, message: {}",
error_name, fn_name, file_name, line_num, col_num, err_msg
);
}
pub struct InputMatrix {
pub ndims: usize,
pub dims: *const usize,
pub data: *mut f64,
}
impl InputMatrix {
fn get_dims(&mut self) -> Vec<usize> {
let dims = unsafe { slice::from_raw_parts(self.dims, self.ndims) };
dims.to_vec()
}
}
/// # Safety
///
/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
if !(mat1.ndims == 1 && mat2.ndims == 1) {
let err_msg = format!(
"expected 1D Vector Input, but received {}-D and {}-D input",
mat1.ndims, mat2.ndims
);
report_error("ValueError", "np_dot", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let dim2 = (*mat2).get_dims();
if dim1[0] != dim2[0] {
let err_msg = format!("shapes ({},) and ({},) not aligned", dim1[0], dim2[0]);
report_error("ValueError", "np_dot", file!(), line!(), column!(), &err_msg);
}
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2);
matrix1.dot(&matrix2)
}
/// # Safety
///
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_matmul(
mat1: *mut InputMatrix,
mat2: *mut InputMatrix,
out: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let mat2 = mat2.as_mut().unwrap();
let out = out.as_mut().unwrap();
if !(mat1.ndims == 2 && mat2.ndims == 2) {
let err_msg = format!(
"expected 2D Vector Input, but received {}-D and {}-D input",
mat1.ndims, mat2.ndims
);
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let dim2 = (*mat2).get_dims();
if dim1[1] != dim2[0] {
let err_msg = format!(
"shapes ({},{}) and ({},{}) not aligned: {} (dim 1) != {} (dim 0)",
dim1[0], dim1[1], dim2[0], dim2[1], dim1[1], dim2[0]
);
report_error("ValueError", "np_matmul", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0] * dim2[1]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let matrix2 = DMatrix::from_row_slice(dim2[0], dim2[1], data_slice2);
let mut result = DMatrix::<f64>::zeros(outdim[0], outdim[1]);
matrix1.mul_to(&matrix2, &mut result);
out_slice.copy_from_slice(result.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_cholesky(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let result = matrix1.cholesky();
match result {
Some(res) => {
out_slice.copy_from_slice(res.unpack().transpose().as_slice());
}
None => {
report_error(
"LinAlgError",
"np_linalg_cholesky",
file!(),
line!(),
column!(),
"Matrix is not positive definite",
);
}
};
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_qr(
mat1: *mut InputMatrix,
out_q: *mut InputMatrix,
out_r: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_q = out_q.as_mut().unwrap();
let out_r = out_r.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
report_error("ValueError", "np_linalg_cholesky", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outq_dim = (*out_q).get_dims();
let outr_dim = (*out_r).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, outq_dim[0] * outq_dim[1]) };
let out_r_slice = unsafe { slice::from_raw_parts_mut(out_r.data, outr_dim[0] * outr_dim[1]) };
// Refer to https://github.com/dimforge/nalgebra/issues/735
let matrix1 = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let res = matrix1.qr();
let (q, r) = res.unpack();
// Uses different algo need to match numpy
out_q_slice.copy_from_slice(q.transpose().as_slice());
out_r_slice.copy_from_slice(r.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_svd(
mat1: *mut InputMatrix,
outu: *mut InputMatrix,
outs: *mut InputMatrix,
outvh: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let outu = outu.as_mut().unwrap();
let outs = outs.as_mut().unwrap();
let outvh = outvh.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
report_error("ValueError", "np_linalg_svd", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outu_dim = (*outu).get_dims();
let outs_dim = (*outs).get_dims();
let outvh_dim = (*outvh).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(outu.data, outu_dim[0] * outu_dim[1]) };
let out_s_slice = unsafe { slice::from_raw_parts_mut(outs.data, outs_dim[0]) };
let out_vh_slice =
unsafe { slice::from_raw_parts_mut(outvh.data, outvh_dim[0] * outvh_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let result = matrix.svd(true, true);
out_u_slice.copy_from_slice(result.u.unwrap().transpose().as_slice());
out_s_slice.copy_from_slice(result.singular_values.as_slice());
out_vh_slice.copy_from_slice(result.v_t.unwrap().transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_inv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
report_error("ValueError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_inv", file!(), line!(), column!(), &err_msg);
}
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
if !matrix.is_invertible() {
report_error(
"LinAlgError",
"np_linalg_inv",
file!(),
line!(),
column!(),
"no inverse for Singular Matrix",
);
}
let inv = matrix.try_inverse().unwrap();
out_slice.copy_from_slice(inv.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn np_linalg_pinv(mat1: *mut InputMatrix, out: *mut InputMatrix) {
let mat1 = mat1.as_mut().unwrap();
let out = out.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
report_error("ValueError", "np_linalg_pinv", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outdim = out.get_dims();
let out_slice = unsafe { slice::from_raw_parts_mut(out.data, outdim[0] * outdim[1]) };
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let svd = matrix.svd(true, true);
let inv = svd.pseudo_inverse(1e-15);
match inv {
Ok(m) => {
out_slice.copy_from_slice(m.transpose().as_slice());
}
Err(err_msg) => {
report_error("LinAlgError", "np_linalg_pinv", file!(), line!(), column!(), err_msg);
}
}
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_lu(
mat1: *mut InputMatrix,
out_l: *mut InputMatrix,
out_u: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_l = out_l.as_mut().unwrap();
let out_u = out_u.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
report_error("ValueError", "sp_linalg_lu", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
let outl_dim = (*out_l).get_dims();
let outu_dim = (*out_u).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_l_slice = unsafe { slice::from_raw_parts_mut(out_l.data, outl_dim[0] * outl_dim[1]) };
let out_u_slice = unsafe { slice::from_raw_parts_mut(out_u.data, outu_dim[0] * outu_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (_, l, u) = matrix.lu().unpack();
out_l_slice.copy_from_slice(l.transpose().as_slice());
out_u_slice.copy_from_slice(u.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_schur(
mat1: *mut InputMatrix,
out_t: *mut InputMatrix,
out_z: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_t = out_t.as_mut().unwrap();
let out_z = out_z.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
report_error("ValueError", "sp_linalg_schur", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {0} != {1}", dim1[0], dim1[1]);
report_error("LinAlgError", "np_linalg_schur", file!(), line!(), column!(), &err_msg);
}
let out_t_dim = (*out_t).get_dims();
let out_z_dim = (*out_z).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_t_slice = unsafe { slice::from_raw_parts_mut(out_t.data, out_t_dim[0] * out_t_dim[1]) };
let out_z_slice = unsafe { slice::from_raw_parts_mut(out_z.data, out_z_dim[0] * out_z_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (z, t) = matrix.schur().unpack();
out_t_slice.copy_from_slice(t.transpose().as_slice());
out_z_slice.copy_from_slice(z.transpose().as_slice());
}
/// # Safety
///
/// `mat1` should point to a valid 2DArray of `f64` floats in row-major order
#[no_mangle]
pub unsafe extern "C" fn sp_linalg_hessenberg(
mat1: *mut InputMatrix,
out_h: *mut InputMatrix,
out_q: *mut InputMatrix,
) {
let mat1 = mat1.as_mut().unwrap();
let out_h = out_h.as_mut().unwrap();
let out_q = out_q.as_mut().unwrap();
if mat1.ndims != 2 {
let err_msg = format!("expected 2D Vector Input, but received {}-D input", mat1.ndims);
report_error("ValueError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
}
let dim1 = (*mat1).get_dims();
if dim1[0] != dim1[1] {
let err_msg =
format!("last 2 dimensions of the array must be square: {} != {}", dim1[0], dim1[1]);
report_error("LinAlgError", "sp_linalg_hessenberg", file!(), line!(), column!(), &err_msg);
}
let out_h_dim = (*out_h).get_dims();
let out_q_dim = (*out_q).get_dims();
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0] * dim1[1]) };
let out_h_slice = unsafe { slice::from_raw_parts_mut(out_h.data, out_h_dim[0] * out_h_dim[1]) };
let out_q_slice = unsafe { slice::from_raw_parts_mut(out_q.data, out_q_dim[0] * out_q_dim[1]) };
let matrix = DMatrix::from_row_slice(dim1[0], dim1[1], data_slice1);
let (q, h) = matrix.hessenberg().unpack();
out_h_slice.copy_from_slice(h.transpose().as_slice());
out_q_slice.copy_from_slice(q.transpose().as_slice());
}

View File

@ -42,11 +42,14 @@ done
if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then if [ -n "$debug" ] && [ -e ../../target/debug/nac3standalone ]; then
nac3standalone=../../target/debug/nac3standalone nac3standalone=../../target/debug/nac3standalone
linalg=../../target/debug/deps/liblinalg-?*.a
elif [ -e ../../target/release/nac3standalone ]; then elif [ -e ../../target/release/nac3standalone ]; then
nac3standalone=../../target/release/nac3standalone nac3standalone=../../target/release/nac3standalone
linalg=../../target/release/deps/liblinalg-?*.a
else else
# used by Nix builds # used by Nix builds
nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone nac3standalone=../../target/x86_64-unknown-linux-gnu/release/nac3standalone
linalg=../../target/x86_64-unknown-linux-gnu/release/deps/liblinalg-?*.a
fi fi
rm -f ./*.o ./*.bc demo rm -f ./*.o ./*.bc demo
@ -55,14 +58,20 @@ if [ -z "$i386" ]; then
$nac3standalone "${nac3args[@]}" $nac3standalone "${nac3args[@]}"
clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c clang -c -std=gnu11 -Wall -Wextra -O3 -o demo.o demo.c
clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o clang -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o $linalg
else else
# Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations # Enable SSE2 to avoid rounding errors with X87's 80-bit fp precision computations
$nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}" $nac3standalone --triple i386-pc-linux-gnu --target-features +sse2 "${nac3args[@]}"
# Compile linalg crate to provide functions compatible with i386 architecture
if [ ! -f ../../target/i686-unknown-linux-gnu/release/liblinalg.a ]; then
cd linalg && nix-shell -p rustup --command "RUSTFLAGS=\"-C target-cpu=i386 -C target-feature=+sse2\" cargo build -q --release --target=i686-unknown-linux-gnu" && cd ..
fi
linalg=../../target/i686-unknown-linux-gnu/release/liblinalg.a
clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c clang -m32 -c -std=gnu11 -Wall -Wextra -O3 -msse2 -o demo.o demo.c
clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o clang -m32 -lm -Wl,--no-warn-search-mismatch -o demo module.o demo.o $linalg
fi fi
if [ -z "$outfile" ]; then if [ -z "$outfile" ]; then

View File

@ -68,6 +68,12 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
for c in range(len(n[r])): for c in range(len(n[r])):
output_float64(n[r][c]) output_float64(n[r][c])
def output_ndarray_float_3(n: ndarray[float, Literal[3]]):
for r in range(len(n)):
for c in range(len(n[r])):
for z in range(len(n[r][c])):
output_float64(n[r][c][z])
def consume_ndarray_1(n: ndarray[float, Literal[1]]): def consume_ndarray_1(n: ndarray[float, Literal[1]]):
pass pass
@ -1429,6 +1435,24 @@ def test_ndarray_nextafter_broadcast_rhs_scalar():
output_ndarray_float_2(nextafter_x_zeros) output_ndarray_float_2(nextafter_x_zeros)
output_ndarray_float_2(nextafter_x_ones) output_ndarray_float_2(nextafter_x_ones)
def test_ndarray_transpose():
# x: ndarray[float, 3] = np_array([[[1., 2.], [3., 4.], [5., 6.]]])
x: ndarray[float, 1] = np_array([1., 2., 3.])
y = np_transpose(x)
output_ndarray_float_1(x)
output_ndarray_float_1(y)
def test_ndarray_reshape():
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
x: ndarray[float, 4] = np_reshape(w, (1, 2, 1, -1))
y: ndarray[float, 2] = np_reshape(x, [2, -1])
z: ndarray[float, 1] = np_reshape(w, 10)
output_ndarray_float_1(w)
output_ndarray_float_2(y)
output_ndarray_float_1(z)
def test_ndarray_dot(): def test_ndarray_dot():
x: ndarray[float, 1] = np_array([5.0, 1.0]) x: ndarray[float, 1] = np_array([5.0, 1.0])
y: ndarray[float, 1] = np_array([5.0, 1.0]) y: ndarray[float, 1] = np_array([5.0, 1.0])
@ -1705,6 +1729,8 @@ def run() -> int32:
test_ndarray_nextafter_broadcast() test_ndarray_nextafter_broadcast()
test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar()
test_ndarray_transpose()
test_ndarray_reshape()
test_ndarray_dot() test_ndarray_dot()
test_ndarray_linalg_matmul() test_ndarray_linalg_matmul()

1
pyo3_output/nac3artiq.so Symbolic link
View File

@ -0,0 +1 @@
../target/debug/libnac3artiq.so