Merge pull request #839 from dimforge/dev

Release v0.25.0
This commit is contained in:
Sébastien Crozet 2021-03-01 14:34:04 +01:00 committed by GitHub
commit 39ef8b43cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
152 changed files with 16011 additions and 1656 deletions

View File

@ -1,119 +0,0 @@
version: 2.1
executors:
rust-nightly-executor:
docker:
- image: rustlang/rust:nightly
rust-executor:
docker:
- image: rust:latest
jobs:
check-fmt:
executor: rust-executor
steps:
- checkout
- run:
name: install rustfmt
command: rustup component add rustfmt
- run:
name: check formatting
command: cargo fmt -- --check
clippy:
executor: rust-executor
steps:
- checkout
- run:
name: install clippy
command: rustup component add clippy
- run:
name: clippy
command: cargo clippy
build-native:
executor: rust-executor
steps:
- checkout
- run: apt-get update
- run: apt-get install -y cmake gfortran libblas-dev liblapack-dev
- run:
name: build --no-default-feature
command: cargo build --no-default-features;
- run:
name: build (default features)
command: cargo build;
- run:
name: build --all-features
command: cargo build --all-features
- run:
name: build nalgebra-glm
command: cargo build -p nalgebra-glm --all-features
- run:
name: build nalgebra-lapack
command: cd nalgebra-lapack; cargo build
test-native:
executor: rust-executor
steps:
- checkout
- run:
name: test
command: cargo test --features arbitrary --features serde-serialize --features abomonation-serialize --features sparse --features debug --features io --features compare --features libm
- run:
name: test nalgebra-glm
command: cargo test -p nalgebra-glm --features arbitrary --features serde-serialize --features abomonation-serialize --features sparse --features debug --features io --features compare --features libm
build-wasm:
executor: rust-executor
steps:
- checkout
- run:
name: install cargo-web
command: cargo install -f cargo-web;
- run:
name: build --all-features
command: cargo web build --verbose --target wasm32-unknown-unknown;
- run:
name: build nalgebra-glm
command: cargo build -p nalgebra-glm --all-features
build-no-std:
executor: rust-nightly-executor
steps:
- checkout
- run:
name: install xargo
command: cp .circleci/Xargo.toml .; rustup component add rust-src; cargo install -f xargo;
- run:
name: build
command: xargo build --verbose --no-default-features --target=x86_64-unknown-linux-gnu;
- run:
name: build --features alloc
command: xargo build --verbose --no-default-features --features alloc --target=x86_64-unknown-linux-gnu;
build-nightly:
executor: rust-nightly-executor
steps:
- checkout
- run:
name: build --all-features
command: cargo build --all-features
workflows:
version: 2
build:
jobs:
- check-fmt
- clippy
- build-native:
requires:
- check-fmt
- build-wasm:
requires:
- check-fmt
- build-no-std:
requires:
- check-fmt
- build-nightly:
requires:
- check-fmt
- test-native:
requires:
- build-native

96
.github/workflows/nalgebra-ci-build.yml vendored Normal file
View File

@ -0,0 +1,96 @@
name: nalgebra CI build
on:
push:
branches: [ dev, master ]
pull_request:
branches: [ dev, master ]
env:
CARGO_TERM_COLOR: always
jobs:
check-fmt:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Check formatting
run: cargo fmt -- --check
clippy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Install clippy
run: rustup component add clippy
- name: Run clippy
run: cargo clippy
build-nalgebra:
runs-on: ubuntu-latest
# env:
# RUSTFLAGS: -D warnings
steps:
- uses: actions/checkout@v2
- name: Build --no-default-feature
run: cargo build --no-default-features;
- name: Build (default features)
run: cargo build;
- name: Build --all-features
run: cargo build --all-features;
- name: Build nalgebra-glm
run: cargo build -p nalgebra-glm --all-features;
- name: Build nalgebra-lapack
run: cd nalgebra-lapack; cargo build;
- name: Build nalgebra-sparse
run: cd nalgebra-sparse; cargo build;
test-nalgebra:
runs-on: ubuntu-latest
# env:
# RUSTFLAGS: -D warnings
steps:
- uses: actions/checkout@v2
- name: test
run: cargo test --features arbitrary --features serde-serialize,abomonation-serialize,sparse,debug,io,compare,libm,proptest-support,slow-tests;
test-nalgebra-glm:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: test nalgebra-glm
run: cargo test -p nalgebra-glm --features arbitrary,serde-serialize,abomonation-serialize,sparse,debug,io,compare,libm,proptest-support,slow-tests;
test-nalgebra-sparse:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: test nalgebra-sparse
# Manifest-path is necessary because cargo otherwise won't correctly forward features
# We increase number of proptest cases to hopefully catch more potential bugs
run: PROPTEST_CASES=10000 cargo test --manifest-path=nalgebra-sparse/Cargo.toml --features compare,proptest-support
- name: test nalgebra-sparse (slow tests)
# Unfortunately, the "slow-tests" take so much time that we need to run them with --release
run: PROPTEST_CASES=10000 cargo test --release --manifest-path=nalgebra-sparse/Cargo.toml --features compare,proptest-support,slow-tests slow
build-wasm:
runs-on: ubuntu-latest
# env:
# RUSTFLAGS: -D warnings
steps:
- uses: actions/checkout@v2
- run: rustup target add wasm32-unknown-unknown
- name: build nalgebra
run: cargo build --verbose --target wasm32-unknown-unknown;
- name: build nalgebra-glm
run: cargo build -p nalgebra-glm --verbose --target wasm32-unknown-unknown;
build-no-std:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Install latest nightly
uses: actions-rs/toolchain@v1
with:
toolchain: nightly
override: true
components: rustfmt
- name: install xargo
run: cp .github/Xargo.toml .; rustup component add rust-src; cargo install -f xargo;
- name: build
run: xargo build --verbose --no-default-features --target=x86_64-unknown-linux-gnu;
- name: build --feature alloc
run: xargo build --verbose --no-default-features --features alloc --target=x86_64-unknown-linux-gnu;

1
.gitignore vendored
View File

@ -10,3 +10,4 @@ Cargo.lock
site/ site/
.vscode/ .vscode/
.idea/ .idea/
proptest-regressions

View File

@ -4,6 +4,27 @@ documented here.
This project adheres to [Semantic Versioning](https://semver.org/). This project adheres to [Semantic Versioning](https://semver.org/).
## [0.25.0]
This updates all the dependencies of nalgebra to their latest version, including:
- rand 0.8
- proptest 1.0
- simba 0.4
### New crate!
Alongside this release of `nalgebra`, we are releasing `nalgebra-sparse`: a crate dedicated to sparse matrix
computation with `nalgebra`. The `sparse` module of `nalgebra`itself still exists for backward compatibility
but it will be deprecated soon in favor of the `nalgebra-sparse` crate.
### Added
* Add `UnitDualQuaternion`, a dual-quaternion with unit magnitude which can be used as an isometry transformation.
* Add `UDU::new()` and `matrix.udu()` to compute the UDU factorization of a matrix.
* Add `ColPivQR::new()` and `matrix.col_piv_qr()` to compute the QR decomposition with column pivoting of a matrix.
* Add `from_basis_unchecked` to all the rotation types. This builds a rotation from a set of basis vectors (representing the columns of the corresponding rotation matrix).
* Add `Matrix::cap_magnitude` to cap the magnitude of a vector.
* Add `UnitQuaternion::append_axisangle_linearized` to approximately append a rotation represented as an axis-angle to a rotation represented as an unit quaternion.
* Mark the iterators on matrix components as `DoubleEndedIter`.
* Re-export `simba::simd::SimdValue` at the root of the `nalgebra` crate.
## [0.24.0] ## [0.24.0]
### Added ### Added
@ -67,7 +88,7 @@ In this release, we are no longer relying on traits from the __alga__ crate for
Instead, we use traits from the new [simba](https://crates.io/crates/simba) crate which are both Instead, we use traits from the new [simba](https://crates.io/crates/simba) crate which are both
simpler, and allow for significant optimizations like AoSoA SIMD. simpler, and allow for significant optimizations like AoSoA SIMD.
Refer to the [monthly Rustsim blogpost](https://www.rustsim.org/blog/2020/04/01/this-month-in-rustsim/) Refer to the [monthly dimforge blogpost](https://www.dimforge.org/blog/2020/04/01/this-month-in-dimforge/)
for details about this switch and its benefits. for details about this switch and its benefits.
### Added ### Added

View File

@ -1,28 +1,29 @@
[package] [package]
name = "nalgebra" name = "nalgebra"
version = "0.24.0" version = "0.25.0"
authors = [ "Sébastien Crozet <developer@crozet.re>" ] authors = [ "Sébastien Crozet <developer@crozet.re>" ]
description = "Linear algebra library with transformations and statically-sized or dynamically-sized matrices." description = "General-purpose linear algebra library with transformations and statically-sized or dynamically-sized matrices."
documentation = "https://nalgebra.org/rustdoc/nalgebra/index.html" documentation = "https://www.nalgebra.org/docs"
homepage = "https://nalgebra.org" homepage = "https://nalgebra.org"
repository = "https://github.com/rustsim/nalgebra" repository = "https://github.com/dimforge/nalgebra"
readme = "README.md" readme = "README.md"
categories = [ "science" ] categories = [ "science", "mathematics", "wasm", "no-std" ]
keywords = [ "linear", "algebra", "matrix", "vector", "math" ] keywords = [ "linear", "algebra", "matrix", "vector", "math" ]
license = "Apache-2.0" license = "BSD-3-Clause"
edition = "2018" edition = "2018"
exclude = ["/ci/*", "/.travis.yml", "/Makefile"] exclude = ["/ci/*", "/.travis.yml", "/Makefile"]
[badges]
maintenance = { status = "actively-developed" }
[lib] [lib]
name = "nalgebra" name = "nalgebra"
path = "src/lib.rs" path = "src/lib.rs"
[features] [features]
default = [ "std" ] default = [ "std" ]
std = [ "matrixmultiply", "rand/std", "rand_distr", "simba/std" ] std = [ "matrixmultiply", "rand/std", "rand/std_rng", "rand_distr", "simba/std" ]
stdweb = [ "rand/stdweb" ]
arbitrary = [ "quickcheck" ] arbitrary = [ "quickcheck" ]
serde-serialize = [ "serde", "num-complex/serde" ] serde-serialize = [ "serde", "num-complex/serde" ]
abomonation-serialize = [ "abomonation" ] abomonation-serialize = [ "abomonation" ]
@ -33,32 +34,39 @@ io = [ "pest", "pest_derive" ]
compare = [ "matrixcompare-core" ] compare = [ "matrixcompare-core" ]
libm = [ "simba/libm" ] libm = [ "simba/libm" ]
libm-force = [ "simba/libm_force" ] libm-force = [ "simba/libm_force" ]
proptest-support = [ "proptest" ]
no_unsound_assume_init = [ ]
# This feature is only used for tests, and enables tests that require more time to run
slow-tests = []
[dependencies] [dependencies]
typenum = "1.12" typenum = "1.12"
generic-array = "0.14" generic-array = "0.14"
rand = { version = "0.7", default-features = false } rand = { version = "0.8", default-features = false }
getrandom = { version = "0.2", default-features = false, features = [ "js" ] } # For wasm
num-traits = { version = "0.2", default-features = false } num-traits = { version = "0.2", default-features = false }
num-complex = { version = "0.3", default-features = false } num-complex = { version = "0.3", default-features = false }
num-rational = { version = "0.3", default-features = false } num-rational = { version = "0.3", default-features = false }
approx = { version = "0.4", default-features = false } approx = { version = "0.4", default-features = false }
simba = { version = "0.3", default-features = false } simba = { version = "0.4", default-features = false }
alga = { version = "0.9", default-features = false, optional = true } alga = { version = "0.9", default-features = false, optional = true }
rand_distr = { version = "0.3", optional = true } rand_distr = { version = "0.4", default-features = false, optional = true }
matrixmultiply = { version = "0.2", optional = true } matrixmultiply = { version = "0.3", optional = true }
serde = { version = "1.0", default-features = false, features = [ "derive" ], optional = true } serde = { version = "1.0", default-features = false, features = [ "derive" ], optional = true }
abomonation = { version = "0.7", optional = true } abomonation = { version = "0.7", optional = true }
mint = { version = "0.5", optional = true } mint = { version = "0.5", optional = true }
quickcheck = { version = "0.9", optional = true } quickcheck = { version = "1", optional = true }
pest = { version = "2", optional = true } pest = { version = "2", optional = true }
pest_derive = { version = "2", optional = true } pest_derive = { version = "2", optional = true }
bytemuck = { version = "1.5", optional = true }
matrixcompare-core = { version = "0.1", optional = true } matrixcompare-core = { version = "0.1", optional = true }
proptest = { version = "1", optional = true, default-features = false, features = ["std"] }
[dev-dependencies] [dev-dependencies]
serde_json = "1.0" serde_json = "1.0"
rand_xorshift = "0.2" rand_xorshift = "0.3"
rand_isaac = "0.2" rand_isaac = "0.3"
### Uncomment this line before running benchmarks. ### Uncomment this line before running benchmarks.
### We can't just let this uncommented because that would break ### We can't just let this uncommented because that would break
### compilation for #[no-std] because of the terrible Cargo bug ### compilation for #[no-std] because of the terrible Cargo bug
@ -66,10 +74,11 @@ rand_isaac = "0.2"
#criterion = "0.2.10" #criterion = "0.2.10"
# For matrix comparison macro # For matrix comparison macro
matrixcompare = "0.1.3" matrixcompare = "0.2.0"
itertools = "0.10"
[workspace] [workspace]
members = [ "nalgebra-lapack", "nalgebra-glm" ] members = [ "nalgebra-lapack", "nalgebra-glm", "nalgebra-sparse" ]
[[bench]] [[bench]]
name = "nalgebra_bench" name = "nalgebra_bench"
@ -78,3 +87,7 @@ path = "benches/lib.rs"
[profile.bench] [profile.bench]
lto = true lto = true
[package.metadata.docs.rs]
# Enable certain features when building docs for docs.rs
features = [ "proptest-support", "compare" ]

View File

@ -17,7 +17,7 @@
</p> </p>
<p align = "center"> <p align = "center">
<strong> <strong>
<a href="https://nalgebra.org">Users guide</a> | <a href="https://nalgebra.org/rustdoc/nalgebra/index.html">Documentation</a> | <a href="https://discourse.nphysics.org/c/nalgebra">Forum</a> <a href="https://nalgebra.org">Users guide</a> | <a href="https://docs.rs/nalgebra/latest/nalgebra/">Documentation</a> | <a href="https://discourse.nphysics.org/c/nalgebra">Forum</a>
</strong> </strong>
</p> </p>

View File

@ -136,6 +136,30 @@ fn mat500_mul_mat500(bench: &mut criterion::Criterion) {
bench.bench_function("mat500_mul_mat500", move |bh| bh.iter(|| &a * &b)); bench.bench_function("mat500_mul_mat500", move |bh| bh.iter(|| &a * &b));
} }
fn iter(bench: &mut criterion::Criterion) {
let a = DMatrix::<f64>::new_random(1000, 1000);
bench.bench_function("iter", move |bh| {
bh.iter(|| {
for value in a.iter() {
criterion::black_box(value);
}
})
});
}
fn iter_rev(bench: &mut criterion::Criterion) {
let a = DMatrix::<f64>::new_random(1000, 1000);
bench.bench_function("iter_rev", move |bh| {
bh.iter(|| {
for value in a.iter().rev() {
criterion::black_box(value);
}
})
});
}
fn copy_from(bench: &mut criterion::Criterion) { fn copy_from(bench: &mut criterion::Criterion) {
let a = DMatrix::<f64>::new_random(1000, 1000); let a = DMatrix::<f64>::new_random(1000, 1000);
let mut b = DMatrix::<f64>::new_random(1000, 1000); let mut b = DMatrix::<f64>::new_random(1000, 1000);
@ -235,6 +259,8 @@ criterion_group!(
mat10_mul_mat10_static, mat10_mul_mat10_static,
mat100_mul_mat100, mat100_mul_mat100,
mat500_mul_mat500, mat500_mul_mat500,
iter,
iter_rev,
copy_from, copy_from,
axpy, axpy,
tr_mul_to, tr_mul_to,

View File

@ -1,22 +1,24 @@
[package] [package]
name = "nalgebra-glm" name = "nalgebra-glm"
version = "0.10.0" version = "0.11.0"
authors = ["sebcrozet <developer@crozet.re>"] authors = ["sebcrozet <developer@crozet.re>"]
description = "A computer-graphics oriented API for nalgebra, inspired by the C++ GLM library." description = "A computer-graphics oriented API for nalgebra, inspired by the C++ GLM library."
documentation = "https://www.nalgebra.org/rustdoc_glm/nalgebra_glm/index.html" documentation = "https://www.nalgebra.org/docs"
homepage = "https://nalgebra.org" homepage = "https://nalgebra.org"
repository = "https://github.com/rustsim/nalgebra" repository = "https://github.com/dimforge/nalgebra"
readme = "../README.md" readme = "../README.md"
categories = [ "science" ] categories = [ "science", "mathematics", "wasm", "no standard library" ]
keywords = [ "linear", "algebra", "matrix", "vector", "math" ] keywords = [ "linear", "algebra", "matrix", "vector", "math" ]
license = "BSD-3-Clause" license = "BSD-3-Clause"
edition = "2018" edition = "2018"
[badges]
maintenance = { status = "actively-developed" }
[features] [features]
default = [ "std" ] default = [ "std" ]
std = [ "nalgebra/std", "simba/std" ] std = [ "nalgebra/std", "simba/std" ]
stdweb = [ "nalgebra/stdweb" ]
arbitrary = [ "nalgebra/arbitrary" ] arbitrary = [ "nalgebra/arbitrary" ]
serde-serialize = [ "nalgebra/serde-serialize" ] serde-serialize = [ "nalgebra/serde-serialize" ]
abomonation-serialize = [ "nalgebra/abomonation-serialize" ] abomonation-serialize = [ "nalgebra/abomonation-serialize" ]
@ -24,5 +26,5 @@ abomonation-serialize = [ "nalgebra/abomonation-serialize" ]
[dependencies] [dependencies]
num-traits = { version = "0.2", default-features = false } num-traits = { version = "0.2", default-features = false }
approx = { version = "0.4", default-features = false } approx = { version = "0.4", default-features = false }
simba = { version = "0.3", default-features = false } simba = { version = "0.4", default-features = false }
nalgebra = { path = "..", version = "0.24", default-features = false } nalgebra = { path = "..", version = "0.25", default-features = false }

View File

@ -1,40 +1,47 @@
[package] [package]
name = "nalgebra-lapack" name = "nalgebra-lapack"
version = "0.15.0" version = "0.16.0"
authors = [ "Sébastien Crozet <developer@crozet.re>", "Andrew Straw <strawman@astraw.com>" ] authors = [ "Sébastien Crozet <developer@crozet.re>", "Andrew Straw <strawman@astraw.com>" ]
description = "Linear algebra library with transformations and satically-sized or dynamically-sized matrices." description = "Matrix decompositions using nalgebra matrices and Lapack bindings."
documentation = "https://nalgebra.org/doc/nalgebra/index.html" documentation = "https://www.nalgebra.org/docs"
homepage = "https://nalgebra.org" homepage = "https://nalgebra.org"
repository = "https://github.com/rustsim/nalgebra" repository = "https://github.com/dimforge/nalgebra"
readme = "README.md" readme = "../README.md"
keywords = [ "linear", "algebra", "matrix", "vector" ] categories = [ "science", "mathematics" ]
license = "BSD-3-Clause" keywords = [ "linear", "algebra", "matrix", "vector", "math", "lapack" ]
edition = "2018" license = "BSD-3-Clause"
edition = "2018"
[badges]
maintenance = { status = "actively-developed" }
[features] [features]
serde-serialize = [ "serde", "serde_derive" ] serde-serialize = [ "serde", "serde_derive" ]
proptest-support = [ "nalgebra/proptest-support" ]
arbitrary = [ "nalgebra/arbitrary" ]
# For BLAS/LAPACK # For BLAS/LAPACK
default = ["openblas"] default = ["netlib"]
openblas = ["lapack-src/openblas"] openblas = ["lapack-src/openblas"]
netlib = ["lapack-src/netlib"] netlib = ["lapack-src/netlib"]
accelerate = ["lapack-src/accelerate"] accelerate = ["lapack-src/accelerate"]
intel-mkl = ["lapack-src/intel-mkl"] intel-mkl = ["lapack-src/intel-mkl"]
[dependencies] [dependencies]
nalgebra = { version = "0.24", path = ".." } nalgebra = { version = "0.25", path = ".." }
num-traits = "0.2" num-traits = "0.2"
num-complex = { version = "0.2", default-features = false } num-complex = { version = "0.3", default-features = false }
simba = "0.2" simba = "0.4"
serde = { version = "1.0", optional = true } serde = { version = "1.0", optional = true }
serde_derive = { version = "1.0", optional = true } serde_derive = { version = "1.0", optional = true }
lapack = { version = "0.16", default-features = false } lapack = { version = "0.17", default-features = false }
lapack-src = { version = "0.5", default-features = false } lapack-src = { version = "0.6", default-features = false }
# clippy = "*" # clippy = "*"
[dev-dependencies] [dev-dependencies]
nalgebra = { version = "0.24", features = [ "arbitrary" ], path = ".." } nalgebra = { version = "0.25", features = [ "arbitrary" ], path = ".." }
quickcheck = "0.9" proptest = { version = "1", default-features = false, features = ["std"] }
approx = "0.3" quickcheck = "1"
rand = "0.7" approx = "0.4"
rand = "0.8"

View File

@ -78,9 +78,9 @@ where
let lda = n as i32; let lda = n as i32;
let mut wr = unsafe { Matrix::new_uninitialized_generic(nrows, U1) }; let mut wr = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
// TODO: Tap into the workspace. // TODO: Tap into the workspace.
let mut wi = unsafe { Matrix::new_uninitialized_generic(nrows, U1) }; let mut wi = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
let mut info = 0; let mut info = 0;
let mut placeholder1 = [N::zero()]; let mut placeholder1 = [N::zero()];
@ -107,8 +107,10 @@ where
match (left_eigenvectors, eigenvectors) { match (left_eigenvectors, eigenvectors) {
(true, true) => { (true, true) => {
let mut vl = unsafe { Matrix::new_uninitialized_generic(nrows, ncols) }; let mut vl =
let mut vr = unsafe { Matrix::new_uninitialized_generic(nrows, ncols) }; unsafe { Matrix::new_uninitialized_generic(nrows, ncols).assume_init() };
let mut vr =
unsafe { Matrix::new_uninitialized_generic(nrows, ncols).assume_init() };
N::xgeev( N::xgeev(
ljob, ljob,
@ -137,7 +139,8 @@ where
} }
} }
(true, false) => { (true, false) => {
let mut vl = unsafe { Matrix::new_uninitialized_generic(nrows, ncols) }; let mut vl =
unsafe { Matrix::new_uninitialized_generic(nrows, ncols).assume_init() };
N::xgeev( N::xgeev(
ljob, ljob,
@ -166,7 +169,8 @@ where
} }
} }
(false, true) => { (false, true) => {
let mut vr = unsafe { Matrix::new_uninitialized_generic(nrows, ncols) }; let mut vr =
unsafe { Matrix::new_uninitialized_generic(nrows, ncols).assume_init() };
N::xgeev( N::xgeev(
ljob, ljob,
@ -243,8 +247,8 @@ where
let lda = n as i32; let lda = n as i32;
let mut wr = unsafe { Matrix::new_uninitialized_generic(nrows, U1) }; let mut wr = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
let mut wi = unsafe { Matrix::new_uninitialized_generic(nrows, U1) }; let mut wi = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
let mut info = 0; let mut info = 0;
let mut placeholder1 = [N::zero()]; let mut placeholder1 = [N::zero()];
@ -287,7 +291,7 @@ where
); );
lapack_panic!(info); lapack_panic!(info);
let mut res = unsafe { Matrix::new_uninitialized_generic(nrows, U1) }; let mut res = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
for i in 0..res.len() { for i in 0..res.len() {
res[i] = Complex::new(wr[i], wi[i]); res[i] = Complex::new(wr[i], wi[i]);

View File

@ -60,7 +60,7 @@ where
"Unable to compute the hessenberg decomposition of an empty matrix." "Unable to compute the hessenberg decomposition of an empty matrix."
); );
let mut tau = unsafe { Matrix::new_uninitialized_generic(nrows.sub(U1), U1) }; let mut tau = unsafe { Matrix::new_uninitialized_generic(nrows.sub(U1), U1).assume_init() };
let mut info = 0; let mut info = 0;
let lwork = let lwork =

View File

@ -57,7 +57,8 @@ where
let (nrows, ncols) = m.data.shape(); let (nrows, ncols) = m.data.shape();
let mut info = 0; let mut info = 0;
let mut tau = unsafe { Matrix::new_uninitialized_generic(nrows.min(ncols), U1) }; let mut tau =
unsafe { Matrix::new_uninitialized_generic(nrows.min(ncols), U1).assume_init() };
if nrows.value() == 0 || ncols.value() == 0 { if nrows.value() == 0 || ncols.value() == 0 {
return Self { qr: m, tau: tau }; return Self { qr: m, tau: tau };

View File

@ -78,9 +78,9 @@ where
let mut info = 0; let mut info = 0;
let mut wr = unsafe { Matrix::new_uninitialized_generic(nrows, U1) }; let mut wr = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
let mut wi = unsafe { Matrix::new_uninitialized_generic(nrows, U1) }; let mut wi = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
let mut q = unsafe { Matrix::new_uninitialized_generic(nrows, ncols) }; let mut q = unsafe { Matrix::new_uninitialized_generic(nrows, ncols).assume_init() };
// Placeholders: // Placeholders:
let mut bwork = [0i32]; let mut bwork = [0i32];
let mut unused = 0; let mut unused = 0;
@ -151,7 +151,8 @@ where
where where
DefaultAllocator: Allocator<Complex<N>, D>, DefaultAllocator: Allocator<Complex<N>, D>,
{ {
let mut out = unsafe { VectorN::new_uninitialized_generic(self.t.data.shape().0, U1) }; let mut out =
unsafe { VectorN::new_uninitialized_generic(self.t.data.shape().0, U1).assume_init() };
for i in 0..out.len() { for i in 0..out.len() {
out[i] = Complex::new(self.re[i], self.im[i]) out[i] = Complex::new(self.re[i], self.im[i])

View File

@ -99,9 +99,9 @@ macro_rules! svd_impl(
let lda = nrows.value() as i32; let lda = nrows.value() as i32;
let mut u = unsafe { Matrix::new_uninitialized_generic(nrows, nrows) }; let mut u = unsafe { Matrix::new_uninitialized_generic(nrows, nrows).assume_init() };
let mut s = unsafe { Matrix::new_uninitialized_generic(nrows.min(ncols), U1) }; let mut s = unsafe { Matrix::new_uninitialized_generic(nrows.min(ncols), U1).assume_init() };
let mut vt = unsafe { Matrix::new_uninitialized_generic(ncols, ncols) }; let mut vt = unsafe { Matrix::new_uninitialized_generic(ncols, ncols).assume_init() };
let ldu = nrows.value(); let ldu = nrows.value();
let ldvt = ncols.value(); let ldvt = ncols.value();

View File

@ -94,7 +94,7 @@ where
let lda = n as i32; let lda = n as i32;
let mut values = unsafe { Matrix::new_uninitialized_generic(nrows, U1) }; let mut values = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
let mut info = 0; let mut info = 0;
let lwork = N::xsyev_work_size(jobz, b'L', n as i32, m.as_mut_slice(), lda, &mut info); let lwork = N::xsyev_work_size(jobz, b'L', n as i32, m.as_mut_slice(), lda, &mut info);

View File

@ -1,8 +1,14 @@
#[macro_use] #[macro_use]
extern crate approx; extern crate approx;
#[cfg(not(feature = "proptest-support"))]
compile_error!("Tests must be run with `proptest-support`");
extern crate nalgebra as na; extern crate nalgebra as na;
extern crate nalgebra_lapack as nl; extern crate nalgebra_lapack as nl;
#[macro_use]
extern crate quickcheck; extern crate lapack;
extern crate lapack_src;
mod linalg; mod linalg;
#[path = "../../tests/proptest/mod.rs"]
mod proptest;

View File

@ -1,101 +1,90 @@
use std::cmp; use std::cmp;
use na::{DMatrix, DVector, Matrix3, Matrix4, Matrix4x3, Vector4}; use na::{DMatrix, DVector, Matrix4x3, Vector4};
use nl::Cholesky; use nl::Cholesky;
quickcheck! { use crate::proptest::*;
fn cholesky(m: DMatrix<f64>) -> bool { use proptest::{prop_assert, proptest};
if m.len() != 0 {
let m = &m * m.transpose();
if let Some(chol) = Cholesky::new(m.clone()) {
let l = chol.unpack();
let reconstructed_m = &l * l.transpose();
return relative_eq!(reconstructed_m, m, epsilon = 1.0e-7) proptest! {
} #[test]
fn cholesky(m in dmatrix()) {
let m = &m * m.transpose();
if let Some(chol) = Cholesky::new(m.clone()) {
let l = chol.unpack();
let reconstructed_m = &l * l.transpose();
prop_assert!(relative_eq!(reconstructed_m, m, epsilon = 1.0e-7));
} }
return true
} }
fn cholesky_static(m: Matrix3<f64>) -> bool { #[test]
fn cholesky_static(m in matrix3()) {
let m = &m * m.transpose(); let m = &m * m.transpose();
if let Some(chol) = Cholesky::new(m) { if let Some(chol) = Cholesky::new(m) {
let l = chol.unpack(); let l = chol.unpack();
let reconstructed_m = &l * l.transpose(); let reconstructed_m = &l * l.transpose();
relative_eq!(reconstructed_m, m, epsilon = 1.0e-7) prop_assert!(relative_eq!(reconstructed_m, m, epsilon = 1.0e-7))
}
else {
false
} }
} }
fn cholesky_solve(n: usize, nb: usize) -> bool { #[test]
if n != 0 { fn cholesky_solve(n in PROPTEST_MATRIX_DIM, nb in PROPTEST_MATRIX_DIM) {
let n = cmp::min(n, 15); // To avoid slowing down the test too much. let n = cmp::min(n, 15); // To avoid slowing down the test too much.
let nb = cmp::min(nb, 15); // To avoid slowing down the test too much. let nb = cmp::min(nb, 15); // To avoid slowing down the test too much.
let m = DMatrix::<f64>::new_random(n, n); let m = DMatrix::<f64>::new_random(n, n);
let m = &m * m.transpose(); let m = &m * m.transpose();
if let Some(chol) = Cholesky::new(m.clone()) { if let Some(chol) = Cholesky::new(m.clone()) {
let b1 = DVector::new_random(n); let b1 = DVector::new_random(n);
let b2 = DMatrix::new_random(n, nb); let b2 = DMatrix::new_random(n, nb);
let sol1 = chol.solve(&b1).unwrap(); let sol1 = chol.solve(&b1).unwrap();
let sol2 = chol.solve(&b2).unwrap(); let sol2 = chol.solve(&b2).unwrap();
return relative_eq!(&m * sol1, b1, epsilon = 1.0e-6) && prop_assert!(relative_eq!(&m * sol1, b1, epsilon = 1.0e-6));
relative_eq!(&m * sol2, b2, epsilon = 1.0e-6) prop_assert!(relative_eq!(&m * sol2, b2, epsilon = 1.0e-6));
}
} }
return true;
} }
fn cholesky_solve_static(m: Matrix4<f64>) -> bool { #[test]
fn cholesky_solve_static(m in matrix4()) {
let m = &m * m.transpose(); let m = &m * m.transpose();
match Cholesky::new(m) { if let Some(chol) = Cholesky::new(m) {
Some(chol) => { let b1 = Vector4::new_random();
let b1 = Vector4::new_random(); let b2 = Matrix4x3::new_random();
let b2 = Matrix4x3::new_random();
let sol1 = chol.solve(&b1).unwrap(); let sol1 = chol.solve(&b1).unwrap();
let sol2 = chol.solve(&b2).unwrap(); let sol2 = chol.solve(&b2).unwrap();
relative_eq!(m * sol1, b1, epsilon = 1.0e-7) && prop_assert!(relative_eq!(m * sol1, b1, epsilon = 1.0e-7));
relative_eq!(m * sol2, b2, epsilon = 1.0e-7) prop_assert!(relative_eq!(m * sol2, b2, epsilon = 1.0e-7));
},
None => true
} }
} }
fn cholesky_inverse(n: usize) -> bool { #[test]
if n != 0 { fn cholesky_inverse(n in PROPTEST_MATRIX_DIM) {
let n = cmp::min(n, 15); // To avoid slowing down the test too much. let n = cmp::min(n, 15); // To avoid slowing down the test too much.
let m = DMatrix::<f64>::new_random(n, n); let m = DMatrix::<f64>::new_random(n, n);
let m = &m * m.transpose(); let m = &m * m.transpose();
if let Some(m1) = Cholesky::new(m.clone()).unwrap().inverse() { if let Some(m1) = Cholesky::new(m.clone()).unwrap().inverse() {
let id1 = &m * &m1; let id1 = &m * &m1;
let id2 = &m1 * &m; let id2 = &m1 * &m;
return id1.is_identity(1.0e-6) && id2.is_identity(1.0e-6); prop_assert!(id1.is_identity(1.0e-6) && id2.is_identity(1.0e-6));
}
} }
return true;
} }
fn cholesky_inverse_static(m: Matrix4<f64>) -> bool { #[test]
fn cholesky_inverse_static(m in matrix4()) {
let m = m * m.transpose(); let m = m * m.transpose();
match Cholesky::new(m.clone()).unwrap().inverse() { if let Some(m1) = Cholesky::new(m.clone()).unwrap().inverse() {
Some(m1) => { let id1 = &m * &m1;
let id1 = &m * &m1; let id2 = &m1 * &m;
let id2 = &m1 * &m;
id1.is_identity(1.0e-5) && id2.is_identity(1.0e-5) prop_assert!(id1.is_identity(1.0e-5) && id2.is_identity(1.0e-5))
},
None => true
} }
} }
} }

View File

@ -1,38 +1,32 @@
use std::cmp; use std::cmp;
use nl::Hessenberg;
use na::{DMatrix, Matrix4}; use na::{DMatrix, Matrix4};
use nl::Hessenberg;
quickcheck!{ use crate::proptest::*;
fn hessenberg(n: usize) -> bool { use proptest::{prop_assert, proptest};
if n != 0 {
let n = cmp::min(n, 25);
let m = DMatrix::<f64>::new_random(n, n);
match Hessenberg::new(m.clone()) { proptest! {
Some(hess) => { #[test]
let h = hess.h(); fn hessenberg(n in PROPTEST_MATRIX_DIM) {
let p = hess.p(); let n = cmp::min(n, 25);
let m = DMatrix::<f64>::new_random(n, n);
relative_eq!(m, &p * h * p.transpose(), epsilon = 1.0e-7) if let Some(hess) = Hessenberg::new(m.clone()) {
}, let h = hess.h();
None => true let p = hess.p();
}
} prop_assert!(relative_eq!(m, &p * h * p.transpose(), epsilon = 1.0e-7))
else {
true
} }
} }
fn hessenberg_static(m: Matrix4<f64>) -> bool { #[test]
match Hessenberg::new(m) { fn hessenberg_static(m in matrix4()) {
Some(hess) => { if let Some(hess) = Hessenberg::new(m) {
let h = hess.h(); let h = hess.h();
let p = hess.p(); let p = hess.p();
relative_eq!(m, p * h * p.transpose(), epsilon = 1.0e-7) prop_assert!(relative_eq!(m, p * h * p.transpose(), epsilon = 1.0e-7))
},
None => true
} }
} }
} }

View File

@ -1,28 +1,28 @@
use std::cmp; use std::cmp;
use na::{DMatrix, DVector, Matrix3x4, Matrix4, Matrix4x3, Vector4}; use na::{DMatrix, DVector, Matrix4x3, Vector4};
use nl::LU; use nl::LU;
quickcheck! { use crate::proptest::*;
fn lup(m: DMatrix<f64>) -> bool { use proptest::{prop_assert, proptest};
if m.len() != 0 {
let lup = LU::new(m.clone());
let l = lup.l();
let u = lup.u();
let mut computed1 = &l * &u;
lup.permute(&mut computed1);
let computed2 = lup.p() * l * u; proptest! {
#[test]
fn lup(m in dmatrix()) {
let lup = LU::new(m.clone());
let l = lup.l();
let u = lup.u();
let mut computed1 = &l * &u;
lup.permute(&mut computed1);
relative_eq!(computed1, m, epsilon = 1.0e-7) && let computed2 = lup.p() * l * u;
relative_eq!(computed2, m, epsilon = 1.0e-7)
} prop_assert!(relative_eq!(computed1, m, epsilon = 1.0e-7));
else { prop_assert!(relative_eq!(computed2, m, epsilon = 1.0e-7));
true
}
} }
fn lu_static(m: Matrix3x4<f64>) -> bool { #[test]
fn lu_static(m in matrix3x5()) {
let lup = LU::new(m); let lup = LU::new(m);
let l = lup.l(); let l = lup.l();
let u = lup.u(); let u = lup.u();
@ -31,37 +31,34 @@ quickcheck! {
let computed2 = lup.p() * l * u; let computed2 = lup.p() * l * u;
relative_eq!(computed1, m, epsilon = 1.0e-7) && prop_assert!(relative_eq!(computed1, m, epsilon = 1.0e-7));
relative_eq!(computed2, m, epsilon = 1.0e-7) prop_assert!(relative_eq!(computed2, m, epsilon = 1.0e-7));
} }
fn lu_solve(n: usize, nb: usize) -> bool { #[test]
if n != 0 { fn lu_solve(n in PROPTEST_MATRIX_DIM, nb in PROPTEST_MATRIX_DIM) {
let n = cmp::min(n, 25); // To avoid slowing down the test too much. let n = cmp::min(n, 25); // To avoid slowing down the test too much.
let nb = cmp::min(nb, 25); // To avoid slowing down the test too much. let nb = cmp::min(nb, 25); // To avoid slowing down the test too much.
let m = DMatrix::<f64>::new_random(n, n); let m = DMatrix::<f64>::new_random(n, n);
let lup = LU::new(m.clone()); let lup = LU::new(m.clone());
let b1 = DVector::new_random(n); let b1 = DVector::new_random(n);
let b2 = DMatrix::new_random(n, nb); let b2 = DMatrix::new_random(n, nb);
let sol1 = lup.solve(&b1).unwrap(); let sol1 = lup.solve(&b1).unwrap();
let sol2 = lup.solve(&b2).unwrap(); let sol2 = lup.solve(&b2).unwrap();
let tr_sol1 = lup.solve_transpose(&b1).unwrap(); let tr_sol1 = lup.solve_transpose(&b1).unwrap();
let tr_sol2 = lup.solve_transpose(&b2).unwrap(); let tr_sol2 = lup.solve_transpose(&b2).unwrap();
relative_eq!(&m * sol1, b1, epsilon = 1.0e-7) && prop_assert!(relative_eq!(&m * sol1, b1, epsilon = 1.0e-7));
relative_eq!(&m * sol2, b2, epsilon = 1.0e-7) && prop_assert!(relative_eq!(&m * sol2, b2, epsilon = 1.0e-7));
relative_eq!(m.transpose() * tr_sol1, b1, epsilon = 1.0e-7) && prop_assert!(relative_eq!(m.transpose() * tr_sol1, b1, epsilon = 1.0e-7));
relative_eq!(m.transpose() * tr_sol2, b2, epsilon = 1.0e-7) prop_assert!(relative_eq!(m.transpose() * tr_sol2, b2, epsilon = 1.0e-7));
}
else {
true
}
} }
fn lu_solve_static(m: Matrix4<f64>) -> bool { #[test]
fn lu_solve_static(m in matrix4()) {
let lup = LU::new(m); let lup = LU::new(m);
let b1 = Vector4::new_random(); let b1 = Vector4::new_random();
let b2 = Matrix4x3::new_random(); let b2 = Matrix4x3::new_random();
@ -71,37 +68,32 @@ quickcheck! {
let tr_sol1 = lup.solve_transpose(&b1).unwrap(); let tr_sol1 = lup.solve_transpose(&b1).unwrap();
let tr_sol2 = lup.solve_transpose(&b2).unwrap(); let tr_sol2 = lup.solve_transpose(&b2).unwrap();
relative_eq!(m * sol1, b1, epsilon = 1.0e-7) && prop_assert!(relative_eq!(m * sol1, b1, epsilon = 1.0e-7));
relative_eq!(m * sol2, b2, epsilon = 1.0e-7) && prop_assert!(relative_eq!(m * sol2, b2, epsilon = 1.0e-7));
relative_eq!(m.transpose() * tr_sol1, b1, epsilon = 1.0e-7) && prop_assert!(relative_eq!(m.transpose() * tr_sol1, b1, epsilon = 1.0e-7));
relative_eq!(m.transpose() * tr_sol2, b2, epsilon = 1.0e-7) prop_assert!(relative_eq!(m.transpose() * tr_sol2, b2, epsilon = 1.0e-7));
} }
fn lu_inverse(n: usize) -> bool { #[test]
if n != 0 { fn lu_inverse(n in PROPTEST_MATRIX_DIM) {
let n = cmp::min(n, 25); // To avoid slowing down the test too much. let n = cmp::min(n, 25); // To avoid slowing down the test too much.
let m = DMatrix::<f64>::new_random(n, n); let m = DMatrix::<f64>::new_random(n, n);
if let Some(m1) = LU::new(m.clone()).inverse() { if let Some(m1) = LU::new(m.clone()).inverse() {
let id1 = &m * &m1; let id1 = &m * &m1;
let id2 = &m1 * &m; let id2 = &m1 * &m;
return id1.is_identity(1.0e-7) && id2.is_identity(1.0e-7); prop_assert!(id1.is_identity(1.0e-7) && id2.is_identity(1.0e-7));
}
} }
return true;
} }
fn lu_inverse_static(m: Matrix4<f64>) -> bool { #[test]
match LU::new(m.clone()).inverse() { fn lu_inverse_static(m in matrix4()) {
Some(m1) => { if let Some(m1) = LU::new(m.clone()).inverse() {
let id1 = &m * &m1; let id1 = &m * &m1;
let id2 = &m1 * &m; let id2 = &m1 * &m;
id1.is_identity(1.0e-5) && id2.is_identity(1.0e-5) prop_assert!(id1.is_identity(1.0e-5) && id2.is_identity(1.0e-5))
},
None => true
} }
} }
} }

View File

@ -1,20 +1,24 @@
use na::{DMatrix, Matrix4x3};
use nl::QR; use nl::QR;
quickcheck! { use crate::proptest::*;
fn qr(m: DMatrix<f64>) -> bool { use proptest::{prop_assert, proptest};
proptest! {
#[test]
fn qr(m in dmatrix()) {
let qr = QR::new(m.clone()); let qr = QR::new(m.clone());
let q = qr.q(); let q = qr.q();
let r = qr.r(); let r = qr.r();
relative_eq!(m, q * r, epsilon = 1.0e-7) prop_assert!(relative_eq!(m, q * r, epsilon = 1.0e-7))
} }
fn qr_static(m: Matrix4x3<f64>) -> bool { #[test]
fn qr_static(m in matrix5x3()) {
let qr = QR::new(m); let qr = QR::new(m);
let q = qr.q(); let q = qr.q();
let r = qr.r(); let r = qr.r();
relative_eq!(m, q * r, epsilon = 1.0e-7) prop_assert!(relative_eq!(m, q * r, epsilon = 1.0e-7))
} }
} }

View File

@ -3,46 +3,40 @@ use std::cmp;
use na::{DMatrix, Matrix4}; use na::{DMatrix, Matrix4};
use nl::Eigen; use nl::Eigen;
quickcheck! { use crate::proptest::*;
fn eigensystem(n: usize) -> bool { use proptest::{prop_assert, proptest};
if n != 0 {
let n = cmp::min(n, 25);
let m = DMatrix::<f64>::new_random(n, n);
match Eigen::new(m.clone(), true, true) { proptest! {
Some(eig) => { #[test]
let eigvals = DMatrix::from_diagonal(&eig.eigenvalues); fn eigensystem(n in PROPTEST_MATRIX_DIM) {
let transformed_eigvectors = &m * eig.eigenvectors.as_ref().unwrap(); let n = cmp::min(n, 25);
let scaled_eigvectors = eig.eigenvectors.as_ref().unwrap() * &eigvals; let m = DMatrix::<f64>::new_random(n, n);
let transformed_left_eigvectors = m.transpose() * eig.left_eigenvectors.as_ref().unwrap(); if let Some(eig) = Eigen::new(m.clone(), true, true) {
let scaled_left_eigvectors = eig.left_eigenvectors.as_ref().unwrap() * &eigvals; let eigvals = DMatrix::from_diagonal(&eig.eigenvalues);
let transformed_eigvectors = &m * eig.eigenvectors.as_ref().unwrap();
let scaled_eigvectors = eig.eigenvectors.as_ref().unwrap() * &eigvals;
relative_eq!(transformed_eigvectors, scaled_eigvectors, epsilon = 1.0e-7) && let transformed_left_eigvectors = m.transpose() * eig.left_eigenvectors.as_ref().unwrap();
relative_eq!(transformed_left_eigvectors, scaled_left_eigvectors, epsilon = 1.0e-7) let scaled_left_eigvectors = eig.left_eigenvectors.as_ref().unwrap() * &eigvals;
},
None => true prop_assert!(relative_eq!(transformed_eigvectors, scaled_eigvectors, epsilon = 1.0e-7));
} prop_assert!(relative_eq!(transformed_left_eigvectors, scaled_left_eigvectors, epsilon = 1.0e-7));
}
else {
true
} }
} }
fn eigensystem_static(m: Matrix4<f64>) -> bool { #[test]
match Eigen::new(m, true, true) { fn eigensystem_static(m in matrix4()) {
Some(eig) => { if let Some(eig) = Eigen::new(m, true, true) {
let eigvals = Matrix4::from_diagonal(&eig.eigenvalues); let eigvals = Matrix4::from_diagonal(&eig.eigenvalues);
let transformed_eigvectors = m * eig.eigenvectors.unwrap(); let transformed_eigvectors = m * eig.eigenvectors.unwrap();
let scaled_eigvectors = eig.eigenvectors.unwrap() * eigvals; let scaled_eigvectors = eig.eigenvectors.unwrap() * eigvals;
let transformed_left_eigvectors = m.transpose() * eig.left_eigenvectors.unwrap(); let transformed_left_eigvectors = m.transpose() * eig.left_eigenvectors.unwrap();
let scaled_left_eigvectors = eig.left_eigenvectors.unwrap() * eigvals; let scaled_left_eigvectors = eig.left_eigenvectors.unwrap() * eigvals;
relative_eq!(transformed_eigvectors, scaled_eigvectors, epsilon = 1.0e-7) && prop_assert!(relative_eq!(transformed_eigvectors, scaled_eigvectors, epsilon = 1.0e-7));
relative_eq!(transformed_left_eigvectors, scaled_left_eigvectors, epsilon = 1.0e-7) prop_assert!(relative_eq!(transformed_left_eigvectors, scaled_left_eigvectors, epsilon = 1.0e-7));
},
None => true
} }
} }
} }

View File

@ -1,20 +1,24 @@
use na::{DMatrix, Matrix4}; use na::DMatrix;
use nl::Schur; use nl::Schur;
use std::cmp; use std::cmp;
quickcheck! { use crate::proptest::*;
fn schur(n: usize) -> bool { use proptest::{prop_assert, proptest};
proptest! {
#[test]
fn schur(n in PROPTEST_MATRIX_DIM) {
let n = cmp::max(1, cmp::min(n, 10)); let n = cmp::max(1, cmp::min(n, 10));
let m = DMatrix::<f64>::new_random(n, n); let m = DMatrix::<f64>::new_random(n, n);
let (vecs, vals) = Schur::new(m.clone()).unpack(); let (vecs, vals) = Schur::new(m.clone()).unpack();
relative_eq!(&vecs * vals * vecs.transpose(), m, epsilon = 1.0e-7) prop_assert!(relative_eq!(&vecs * vals * vecs.transpose(), m, epsilon = 1.0e-7))
} }
fn schur_static(m: Matrix4<f64>) -> bool { #[test]
fn schur_static(m in matrix4()) {
let (vecs, vals) = Schur::new(m.clone()).unpack(); let (vecs, vals) = Schur::new(m.clone()).unpack();
prop_assert!(relative_eq!(vecs * vals * vecs.transpose(), m, epsilon = 1.0e-7))
relative_eq!(vecs * vals * vecs.transpose(), m, epsilon = 1.0e-7)
} }
} }

View File

@ -1,57 +1,53 @@
use na::{DMatrix, Matrix3x4}; use na::{DMatrix, Matrix3x5};
use nl::SVD; use nl::SVD;
quickcheck! { use crate::proptest::*;
fn svd(m: DMatrix<f64>) -> bool { use proptest::{prop_assert, proptest};
if m.nrows() != 0 && m.ncols() != 0 {
let svd = SVD::new(m.clone()).unwrap();
let sm = DMatrix::from_partial_diagonal(m.nrows(), m.ncols(), svd.singular_values.as_slice());
let reconstructed_m = &svd.u * sm * &svd.vt; proptest! {
let reconstructed_m2 = svd.recompose(); #[test]
fn svd(m in dmatrix()) {
let svd = SVD::new(m.clone()).unwrap();
let sm = DMatrix::from_partial_diagonal(m.nrows(), m.ncols(), svd.singular_values.as_slice());
relative_eq!(reconstructed_m, m, epsilon = 1.0e-7) && let reconstructed_m = &svd.u * sm * &svd.vt;
relative_eq!(reconstructed_m2, reconstructed_m, epsilon = 1.0e-7) let reconstructed_m2 = svd.recompose();
}
else { prop_assert!(relative_eq!(reconstructed_m, m, epsilon = 1.0e-7));
true prop_assert!(relative_eq!(reconstructed_m2, reconstructed_m, epsilon = 1.0e-7));
}
} }
fn svd_static(m: Matrix3x4<f64>) -> bool { #[test]
fn svd_static(m in matrix3x5()) {
let svd = SVD::new(m).unwrap(); let svd = SVD::new(m).unwrap();
let sm = Matrix3x4::from_partial_diagonal(svd.singular_values.as_slice()); let sm = Matrix3x5::from_partial_diagonal(svd.singular_values.as_slice());
let reconstructed_m = &svd.u * &sm * &svd.vt; let reconstructed_m = &svd.u * &sm * &svd.vt;
let reconstructed_m2 = svd.recompose(); let reconstructed_m2 = svd.recompose();
relative_eq!(reconstructed_m, m, epsilon = 1.0e-7) && prop_assert!(relative_eq!(reconstructed_m, m, epsilon = 1.0e-7));
relative_eq!(reconstructed_m2, m, epsilon = 1.0e-7) prop_assert!(relative_eq!(reconstructed_m2, m, epsilon = 1.0e-7));
} }
fn pseudo_inverse(m: DMatrix<f64>) -> bool { #[test]
if m.nrows() == 0 || m.ncols() == 0 { fn pseudo_inverse(m in dmatrix()) {
return true;
}
let svd = SVD::new(m.clone()).unwrap(); let svd = SVD::new(m.clone()).unwrap();
let im = svd.pseudo_inverse(1.0e-7); let im = svd.pseudo_inverse(1.0e-7);
if m.nrows() <= m.ncols() { if m.nrows() <= m.ncols() {
return (&m * &im).is_identity(1.0e-7) prop_assert!((&m * &im).is_identity(1.0e-7));
} }
if m.nrows() >= m.ncols() { if m.nrows() >= m.ncols() {
return (im * m).is_identity(1.0e-7) prop_assert!((im * m).is_identity(1.0e-7));
} }
return true;
} }
fn pseudo_inverse_static(m: Matrix3x4<f64>) -> bool { #[test]
fn pseudo_inverse_static(m in matrix3x5()) {
let svd = SVD::new(m).unwrap(); let svd = SVD::new(m).unwrap();
let im = svd.pseudo_inverse(1.0e-7); let im = svd.pseudo_inverse(1.0e-7);
(m * im).is_identity(1.0e-7) prop_assert!((m * im).is_identity(1.0e-7))
} }
} }

View File

@ -1,20 +1,25 @@
use std::cmp; use std::cmp;
use na::{DMatrix, Matrix4}; use na::DMatrix;
use nl::SymmetricEigen; use nl::SymmetricEigen;
quickcheck! { use crate::proptest::*;
fn symmetric_eigen(n: usize) -> bool { use proptest::{prop_assert, proptest};
proptest! {
#[test]
fn symmetric_eigen(n in PROPTEST_MATRIX_DIM) {
let n = cmp::max(1, cmp::min(n, 10)); let n = cmp::max(1, cmp::min(n, 10));
let m = DMatrix::<f64>::new_random(n, n); let m = DMatrix::<f64>::new_random(n, n);
let eig = SymmetricEigen::new(m.clone()); let eig = SymmetricEigen::new(m.clone());
let recomp = eig.recompose(); let recomp = eig.recompose();
relative_eq!(m.lower_triangle(), recomp.lower_triangle(), epsilon = 1.0e-5) prop_assert!(relative_eq!(m.lower_triangle(), recomp.lower_triangle(), epsilon = 1.0e-5))
} }
fn symmetric_eigen_static(m: Matrix4<f64>) -> bool { #[test]
fn symmetric_eigen_static(m in matrix4()) {
let eig = SymmetricEigen::new(m); let eig = SymmetricEigen::new(m);
let recomp = eig.recompose(); let recomp = eig.recompose();
relative_eq!(m.lower_triangle(), recomp.lower_triangle(), epsilon = 1.0e-5) prop_assert!(relative_eq!(m.lower_triangle(), recomp.lower_triangle(), epsilon = 1.0e-5))
} }
} }

View File

@ -0,0 +1,27 @@
[package]
name = "nalgebra-sparse"
version = "0.1.0"
authors = [ "Andreas Longva", "Sébastien Crozet <developer@crozet.re>" ]
edition = "2018"
[features]
proptest-support = ["proptest", "nalgebra/proptest-support"]
compare = [ "matrixcompare-core" ]
# Enable to enable running some tests that take a lot of time to run
slow-tests = []
[dependencies]
nalgebra = { version="0.25", path = "../" }
num-traits = { version = "0.2", default-features = false }
proptest = { version = "1.0", optional = true }
matrixcompare-core = { version = "0.1.0", optional = true }
[dev-dependencies]
itertools = "0.10"
matrixcompare = { version = "0.2.0", features = [ "proptest-support" ] }
nalgebra = { version="0.25", path = "../", features = ["compare"] }
[package.metadata.docs.rs]
# Enable certain features when building docs for docs.rs
features = [ "proptest-support", "compare" ]

View File

@ -0,0 +1,124 @@
use crate::convert::serial::*;
use crate::coo::CooMatrix;
use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use nalgebra::storage::Storage;
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
use num_traits::Zero;
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CooMatrix<T>
where
T: Scalar + Zero,
R: Dim,
C: Dim,
S: Storage<T, R, C>,
{
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
convert_dense_coo(matrix)
}
}
impl<'a, T> From<&'a CooMatrix<T>> for DMatrix<T>
where
T: Scalar + Zero + ClosedAdd,
{
fn from(coo: &'a CooMatrix<T>) -> Self {
convert_coo_dense(coo)
}
}
impl<'a, T> From<&'a CooMatrix<T>> for CsrMatrix<T>
where
T: Scalar + Zero + ClosedAdd,
{
fn from(matrix: &'a CooMatrix<T>) -> Self {
convert_coo_csr(matrix)
}
}
impl<'a, T> From<&'a CsrMatrix<T>> for CooMatrix<T>
where
T: Scalar + Zero + ClosedAdd,
{
fn from(matrix: &'a CsrMatrix<T>) -> Self {
convert_csr_coo(matrix)
}
}
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CsrMatrix<T>
where
T: Scalar + Zero,
R: Dim,
C: Dim,
S: Storage<T, R, C>,
{
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
convert_dense_csr(matrix)
}
}
impl<'a, T> From<&'a CsrMatrix<T>> for DMatrix<T>
where
T: Scalar + Zero + ClosedAdd,
{
fn from(matrix: &'a CsrMatrix<T>) -> Self {
convert_csr_dense(matrix)
}
}
impl<'a, T> From<&'a CooMatrix<T>> for CscMatrix<T>
where
T: Scalar + Zero + ClosedAdd,
{
fn from(matrix: &'a CooMatrix<T>) -> Self {
convert_coo_csc(matrix)
}
}
impl<'a, T> From<&'a CscMatrix<T>> for CooMatrix<T>
where
T: Scalar + Zero,
{
fn from(matrix: &'a CscMatrix<T>) -> Self {
convert_csc_coo(matrix)
}
}
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CscMatrix<T>
where
T: Scalar + Zero,
R: Dim,
C: Dim,
S: Storage<T, R, C>,
{
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
convert_dense_csc(matrix)
}
}
impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T>
where
T: Scalar + Zero + ClosedAdd,
{
fn from(matrix: &'a CscMatrix<T>) -> Self {
convert_csc_dense(matrix)
}
}
impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T>
where
T: Scalar,
{
fn from(matrix: &'a CscMatrix<T>) -> Self {
convert_csc_csr(matrix)
}
}
impl<'a, T> From<&'a CsrMatrix<T>> for CscMatrix<T>
where
T: Scalar,
{
fn from(matrix: &'a CsrMatrix<T>) -> Self {
convert_csr_csc(matrix)
}
}

View File

@ -0,0 +1,40 @@
//! Routines for converting between sparse matrix formats.
//!
//! Most users should instead use the provided `From` implementations to convert between matrix
//! formats. Note that `From` implementations may not be available between all combinations of
//! sparse matrices.
//!
//! The following example illustrates how to convert between matrix formats with the `From`
//! implementations.
//!
//! ```rust
//! use nalgebra_sparse::{csr::CsrMatrix, csc::CscMatrix, coo::CooMatrix};
//! use nalgebra::DMatrix;
//!
//! // Conversion from dense
//! let dense = DMatrix::<f64>::identity(9, 8);
//! let csr = CsrMatrix::from(&dense);
//! let csc = CscMatrix::from(&dense);
//! let coo = CooMatrix::from(&dense);
//!
//! // CSR <-> CSC
//! let _ = CsrMatrix::from(&csc);
//! let _ = CscMatrix::from(&csr);
//!
//! // CSR <-> COO
//! let _ = CooMatrix::from(&csr);
//! let _ = CsrMatrix::from(&coo);
//!
//! // CSC <-> COO
//! let _ = CooMatrix::from(&csc);
//! let _ = CscMatrix::from(&coo);
//! ```
//!
//! The routines available here are able to provide more specialized APIs, giving
//! more control over the conversion process. The routines are organized by backends.
//! Currently, only the [`serial`] backend is available.
//! In the future, backends that offer parallel routines may become available.
pub mod serial;
mod impl_std_ops;

View File

@ -0,0 +1,427 @@
//! Serial routines for converting between matrix formats.
//!
//! All routines in this module are single-threaded. At present these routines offer no
//! advantage over using the [`From`] trait, but future changes to the API might offer more
//! control to the user.
use std::ops::Add;
use num_traits::Zero;
use nalgebra::storage::Storage;
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
use crate::coo::CooMatrix;
use crate::cs;
use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
/// Converts a dense matrix to [`CooMatrix`].
pub fn convert_dense_coo<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CooMatrix<T>
where
T: Scalar + Zero,
R: Dim,
C: Dim,
S: Storage<T, R, C>,
{
let mut coo = CooMatrix::new(dense.nrows(), dense.ncols());
for (index, v) in dense.iter().enumerate() {
if v != &T::zero() {
// We use the fact that matrix iteration is guaranteed to be column-major
let i = index % dense.nrows();
let j = index / dense.nrows();
coo.push(i, j, v.inlined_clone());
}
}
coo
}
/// Converts a [`CooMatrix`] to a dense matrix.
pub fn convert_coo_dense<T>(coo: &CooMatrix<T>) -> DMatrix<T>
where
T: Scalar + Zero + ClosedAdd,
{
let mut output = DMatrix::repeat(coo.nrows(), coo.ncols(), T::zero());
for (i, j, v) in coo.triplet_iter() {
output[(i, j)] += v.inlined_clone();
}
output
}
/// Converts a [`CooMatrix`] to a [`CsrMatrix`].
pub fn convert_coo_csr<T>(coo: &CooMatrix<T>) -> CsrMatrix<T>
where
T: Scalar + Zero,
{
let (offsets, indices, values) = convert_coo_cs(
coo.nrows(),
coo.row_indices(),
coo.col_indices(),
coo.values(),
);
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
// to see if it can be justified for performance reasons)
CsrMatrix::try_from_csr_data(coo.nrows(), coo.ncols(), offsets, indices, values)
.expect("Internal error: Invalid CSR data during COO->CSR conversion")
}
/// Converts a [`CsrMatrix`] to a [`CooMatrix`].
pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T> {
let mut result = CooMatrix::new(csr.nrows(), csr.ncols());
for (i, j, v) in csr.triplet_iter() {
result.push(i, j, v.inlined_clone());
}
result
}
/// Converts a [`CsrMatrix`] to a dense matrix.
pub fn convert_csr_dense<T>(csr: &CsrMatrix<T>) -> DMatrix<T>
where
T: Scalar + ClosedAdd + Zero,
{
let mut output = DMatrix::zeros(csr.nrows(), csr.ncols());
for (i, j, v) in csr.triplet_iter() {
output[(i, j)] += v.inlined_clone();
}
output
}
/// Converts a dense matrix to a [`CsrMatrix`].
pub fn convert_dense_csr<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CsrMatrix<T>
where
T: Scalar + Zero,
R: Dim,
C: Dim,
S: Storage<T, R, C>,
{
let mut row_offsets = Vec::with_capacity(dense.nrows() + 1);
let mut col_idx = Vec::new();
let mut values = Vec::new();
// We have to iterate row-by-row to build the CSR matrix, which is at odds with
// nalgebra's column-major storage. The alternative would be to perform an initial sweep
// to count number of non-zeros per row.
row_offsets.push(0);
for i in 0..dense.nrows() {
for j in 0..dense.ncols() {
let v = dense.index((i, j));
if v != &T::zero() {
col_idx.push(j);
values.push(v.inlined_clone());
}
}
row_offsets.push(col_idx.len());
}
// TODO: Consider circumventing the data validity check here
// (would require unsafe, should benchmark)
CsrMatrix::try_from_csr_data(dense.nrows(), dense.ncols(), row_offsets, col_idx, values)
.expect("Internal error: Invalid CsrMatrix format during dense-> CSR conversion")
}
/// Converts a [`CooMatrix`] to a [`CscMatrix`].
pub fn convert_coo_csc<T>(coo: &CooMatrix<T>) -> CscMatrix<T>
where
T: Scalar + Zero,
{
let (offsets, indices, values) = convert_coo_cs(
coo.ncols(),
coo.col_indices(),
coo.row_indices(),
coo.values(),
);
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
// to see if it can be justified for performance reasons)
CscMatrix::try_from_csc_data(coo.nrows(), coo.ncols(), offsets, indices, values)
.expect("Internal error: Invalid CSC data during COO->CSC conversion")
}
/// Converts a [`CscMatrix`] to a [`CooMatrix`].
pub fn convert_csc_coo<T>(csc: &CscMatrix<T>) -> CooMatrix<T>
where
T: Scalar,
{
let mut coo = CooMatrix::new(csc.nrows(), csc.ncols());
for (i, j, v) in csc.triplet_iter() {
coo.push(i, j, v.inlined_clone());
}
coo
}
/// Converts a [`CscMatrix`] to a dense matrix.
pub fn convert_csc_dense<T>(csc: &CscMatrix<T>) -> DMatrix<T>
where
T: Scalar + ClosedAdd + Zero,
{
let mut output = DMatrix::zeros(csc.nrows(), csc.ncols());
for (i, j, v) in csc.triplet_iter() {
output[(i, j)] += v.inlined_clone();
}
output
}
/// Converts a dense matrix to a [`CscMatrix`].
pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
where
T: Scalar + Zero,
R: Dim,
C: Dim,
S: Storage<T, R, C>,
{
let mut col_offsets = Vec::with_capacity(dense.ncols() + 1);
let mut row_idx = Vec::new();
let mut values = Vec::new();
col_offsets.push(0);
for j in 0..dense.ncols() {
for i in 0..dense.nrows() {
let v = dense.index((i, j));
if v != &T::zero() {
row_idx.push(i);
values.push(v.inlined_clone());
}
}
col_offsets.push(row_idx.len());
}
// TODO: Consider circumventing the data validity check here
// (would require unsafe, should benchmark)
CscMatrix::try_from_csc_data(dense.nrows(), dense.ncols(), col_offsets, row_idx, values)
.expect("Internal error: Invalid CscMatrix format during dense-> CSC conversion")
}
/// Converts a [`CsrMatrix`] to a [`CscMatrix`].
pub fn convert_csr_csc<T>(csr: &CsrMatrix<T>) -> CscMatrix<T>
where
T: Scalar,
{
let (offsets, indices, values) = cs::transpose_cs(
csr.nrows(),
csr.ncols(),
csr.row_offsets(),
csr.col_indices(),
csr.values(),
);
// TODO: Avoid data validity check?
CscMatrix::try_from_csc_data(csr.nrows(), csr.ncols(), offsets, indices, values)
.expect("Internal error: Invalid CSC data during CSR->CSC conversion")
}
/// Converts a [`CscMatrix`] to a [`CsrMatrix`].
pub fn convert_csc_csr<T>(csc: &CscMatrix<T>) -> CsrMatrix<T>
where
T: Scalar,
{
let (offsets, indices, values) = cs::transpose_cs(
csc.ncols(),
csc.nrows(),
csc.col_offsets(),
csc.row_indices(),
csc.values(),
);
// TODO: Avoid data validity check?
CsrMatrix::try_from_csr_data(csc.nrows(), csc.ncols(), offsets, indices, values)
.expect("Internal error: Invalid CSR data during CSC->CSR conversion")
}
fn convert_coo_cs<T>(
major_dim: usize,
major_indices: &[usize],
minor_indices: &[usize],
values: &[T],
) -> (Vec<usize>, Vec<usize>, Vec<T>)
where
T: Scalar + Zero,
{
assert_eq!(major_indices.len(), minor_indices.len());
assert_eq!(minor_indices.len(), values.len());
let nnz = major_indices.len();
let (unsorted_major_offsets, unsorted_minor_idx, unsorted_vals) = {
let mut offsets = vec![0usize; major_dim + 1];
let mut minor_idx = vec![0usize; nnz];
let mut vals = vec![T::zero(); nnz];
coo_to_unsorted_cs(
&mut offsets,
&mut minor_idx,
&mut vals,
major_dim,
major_indices,
minor_indices,
values,
);
(offsets, minor_idx, vals)
};
// TODO: If input is sorted and/or without duplicates, we can avoid additional allocations
// and work. Might want to take advantage of this.
// At this point, assembly is essentially complete. However, we must ensure
// that minor indices are sorted within each lane and without duplicates.
let mut sorted_major_offsets = Vec::new();
let mut sorted_minor_idx = Vec::new();
let mut sorted_vals = Vec::new();
sorted_major_offsets.push(0);
// We need some temporary storage when working with each lane. Since lanes often have a
// very small number of non-zero entries, we try to amortize allocations across
// lanes by reusing workspace vectors
let mut idx_workspace = Vec::new();
let mut perm_workspace = Vec::new();
let mut values_workspace = Vec::new();
for lane in 0..major_dim {
let begin = unsorted_major_offsets[lane];
let end = unsorted_major_offsets[lane + 1];
let count = end - begin;
let range = begin..end;
// Ensure that workspaces can hold enough data
perm_workspace.resize(count, 0);
idx_workspace.resize(count, 0);
values_workspace.resize(count, T::zero());
sort_lane(
&mut idx_workspace[..count],
&mut values_workspace[..count],
&unsorted_minor_idx[range.clone()],
&unsorted_vals[range.clone()],
&mut perm_workspace[..count],
);
let sorted_ja_current_len = sorted_minor_idx.len();
combine_duplicates(
|idx| sorted_minor_idx.push(idx),
|val| sorted_vals.push(val),
&idx_workspace[..count],
&values_workspace[..count],
&Add::add,
);
let new_col_count = sorted_minor_idx.len() - sorted_ja_current_len;
sorted_major_offsets.push(sorted_major_offsets.last().unwrap() + new_col_count);
}
(sorted_major_offsets, sorted_minor_idx, sorted_vals)
}
/// Converts matrix data given in triplet format to unsorted CSR/CSC, retaining any duplicated
/// indices.
///
/// Here `major/minor` is `row/col` for CSR and `col/row` for CSC.
fn coo_to_unsorted_cs<T: Clone>(
major_offsets: &mut [usize],
cs_minor_idx: &mut [usize],
cs_values: &mut [T],
major_dim: usize,
major_indices: &[usize],
minor_indices: &[usize],
coo_values: &[T],
) {
assert_eq!(major_offsets.len(), major_dim + 1);
assert_eq!(cs_minor_idx.len(), cs_values.len());
assert_eq!(cs_values.len(), major_indices.len());
assert_eq!(major_indices.len(), minor_indices.len());
assert_eq!(minor_indices.len(), coo_values.len());
// Count the number of occurrences of each row
for major_idx in major_indices {
major_offsets[*major_idx] += 1;
}
cs::convert_counts_to_offsets(major_offsets);
{
// TODO: Instead of allocating a whole new vector storing the current counts,
// I think it's possible to be a bit more clever by storing each count
// in the last of the column indices for each row
let mut current_counts = vec![0usize; major_dim + 1];
let triplet_iter = major_indices.iter().zip(minor_indices).zip(coo_values);
for ((i, j), value) in triplet_iter {
let current_offset = major_offsets[*i] + current_counts[*i];
cs_minor_idx[current_offset] = *j;
cs_values[current_offset] = value.clone();
current_counts[*i] += 1;
}
}
}
/// Sort the indices of the given lane.
///
/// The indices and values in `minor_idx` and `values` are sorted according to the
/// minor indices and stored in `minor_idx_result` and `values_result` respectively.
///
/// All input slices are expected to be of the same length. The contents of mutable slices
/// can be arbitrary, as they are anyway overwritten.
fn sort_lane<T: Clone>(
minor_idx_result: &mut [usize],
values_result: &mut [T],
minor_idx: &[usize],
values: &[T],
workspace: &mut [usize],
) {
assert_eq!(minor_idx_result.len(), values_result.len());
assert_eq!(values_result.len(), minor_idx.len());
assert_eq!(minor_idx.len(), values.len());
assert_eq!(values.len(), workspace.len());
let permutation = workspace;
// Set permutation to identity
for (i, p) in permutation.iter_mut().enumerate() {
*p = i;
}
// Compute permutation needed to bring minor indices into sorted order
// Note: Using sort_unstable here avoids internal allocations, which is crucial since
// each lane might have a small number of elements
permutation.sort_unstable_by_key(|idx| minor_idx[*idx]);
apply_permutation(minor_idx_result, minor_idx, permutation);
apply_permutation(values_result, values, permutation);
}
// TODO: Move this into `utils` or something?
fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {
assert_eq!(out_slice.len(), in_slice.len());
assert_eq!(out_slice.len(), permutation.len());
for (out_element, old_pos) in out_slice.iter_mut().zip(permutation) {
*out_element = in_slice[*old_pos].clone();
}
}
/// Given *sorted* indices and corresponding scalar values, combines duplicates with the given
/// associative combiner and calls the provided produce methods with combined indices and values.
fn combine_duplicates<T: Clone>(
mut produce_idx: impl FnMut(usize),
mut produce_value: impl FnMut(T),
idx_array: &[usize],
values: &[T],
combiner: impl Fn(T, T) -> T,
) {
assert_eq!(idx_array.len(), values.len());
let mut i = 0;
while i < idx_array.len() {
let idx = idx_array[i];
let mut combined_value = values[i].clone();
let mut j = i + 1;
while j < idx_array.len() && idx_array[j] == idx {
let j_val = values[j].clone();
combined_value = combiner(combined_value, j_val);
j += 1;
}
produce_idx(idx);
produce_value(combined_value);
i = j;
}
}

208
nalgebra-sparse/src/coo.rs Normal file
View File

@ -0,0 +1,208 @@
//! An implementation of the COO sparse matrix format.
use crate::SparseFormatError;
/// A COO representation of a sparse matrix.
///
/// A COO matrix stores entries in coordinate-form, that is triplets `(i, j, v)`, where `i` and `j`
/// correspond to row and column indices of the entry, and `v` to the value of the entry.
/// The format is of limited use for standard matrix operations. Its main purpose is to facilitate
/// easy construction of other, more efficient matrix formats (such as CSR/COO), and the
/// conversion between different formats.
///
/// # Format
///
/// For given dimensions `nrows` and `ncols`, the matrix is represented by three same-length
/// arrays `row_indices`, `col_indices` and `values` that constitute the coordinate triplets
/// of the matrix. The indices must be in bounds, but *duplicate entries are explicitly allowed*.
/// Upon conversion to other formats, the duplicate entries may be summed together. See the
/// documentation for the respective conversion functions.
///
/// # Examples
///
/// ```rust
/// use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix, csc::CscMatrix};
///
/// // Initialize a matrix with all zeros (no explicitly stored entries).
/// let mut coo = CooMatrix::new(4, 4);
/// // Or initialize it with a set of triplets
/// coo = CooMatrix::try_from_triplets(4, 4, vec![1, 2], vec![0, 1], vec![3.0, 4.0]).unwrap();
///
/// // Push a few triplets
/// coo.push(2, 0, 1.0);
/// coo.push(0, 1, 2.0);
///
/// // Convert to other matrix formats
/// let csr = CsrMatrix::from(&coo);
/// let csc = CscMatrix::from(&coo);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CooMatrix<T> {
nrows: usize,
ncols: usize,
row_indices: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
}
impl<T> CooMatrix<T> {
/// Construct a zero COO matrix of the given dimensions.
///
/// Specifically, the collection of triplets - corresponding to explicitly stored entries -
/// is empty, so that the matrix (implicitly) represented by the COO matrix consists of all
/// zero entries.
pub fn new(nrows: usize, ncols: usize) -> Self {
Self {
nrows,
ncols,
row_indices: Vec::new(),
col_indices: Vec::new(),
values: Vec::new(),
}
}
/// Construct a zero COO matrix of the given dimensions.
///
/// Specifically, the collection of triplets - corresponding to explicitly stored entries -
/// is empty, so that the matrix (implicitly) represented by the COO matrix consists of all
/// zero entries.
pub fn zeros(nrows: usize, ncols: usize) -> Self {
Self::new(nrows, ncols)
}
/// Try to construct a COO matrix from the given dimensions and a collection of
/// (i, j, v) triplets.
///
/// Returns an error if either row or column indices contain indices out of bounds,
/// or if the data arrays do not all have the same length. Note that the COO format
/// inherently supports duplicate entries.
pub fn try_from_triplets(
nrows: usize,
ncols: usize,
row_indices: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
use crate::SparseFormatErrorKind::*;
if row_indices.len() != col_indices.len() {
return Err(SparseFormatError::from_kind_and_msg(
InvalidStructure,
"Number of row and col indices must be the same.",
));
} else if col_indices.len() != values.len() {
return Err(SparseFormatError::from_kind_and_msg(
InvalidStructure,
"Number of col indices and values must be the same.",
));
}
let row_indices_in_bounds = row_indices.iter().all(|i| *i < nrows);
let col_indices_in_bounds = col_indices.iter().all(|j| *j < ncols);
if !row_indices_in_bounds {
Err(SparseFormatError::from_kind_and_msg(
IndexOutOfBounds,
"Row index out of bounds.",
))
} else if !col_indices_in_bounds {
Err(SparseFormatError::from_kind_and_msg(
IndexOutOfBounds,
"Col index out of bounds.",
))
} else {
Ok(Self {
nrows,
ncols,
row_indices,
col_indices,
values,
})
}
}
/// An iterator over triplets (i, j, v).
// TODO: Consider giving the iterator a concrete type instead of impl trait...?
pub fn triplet_iter(&self) -> impl Iterator<Item = (usize, usize, &T)> {
self.row_indices
.iter()
.zip(&self.col_indices)
.zip(&self.values)
.map(|((i, j), v)| (*i, *j, v))
}
/// Push a single triplet to the matrix.
///
/// This adds the value `v` to the `i`th row and `j`th column in the matrix.
///
/// Panics
/// ------
///
/// Panics if `i` or `j` is out of bounds.
#[inline]
pub fn push(&mut self, i: usize, j: usize, v: T) {
assert!(i < self.nrows);
assert!(j < self.ncols);
self.row_indices.push(i);
self.col_indices.push(j);
self.values.push(v);
}
/// The number of rows in the matrix.
#[inline]
pub fn nrows(&self) -> usize {
self.nrows
}
/// The number of columns in the matrix.
#[inline]
pub fn ncols(&self) -> usize {
self.ncols
}
/// The number of explicitly stored entries in the matrix.
///
/// This number *includes* duplicate entries. For example, if the `CooMatrix` contains duplicate
/// entries, then it may have a different number of non-zeros as reported by `nnz()` compared
/// to its CSR representation.
#[inline]
pub fn nnz(&self) -> usize {
self.values.len()
}
/// The row indices of the explicitly stored entries.
pub fn row_indices(&self) -> &[usize] {
&self.row_indices
}
/// The column indices of the explicitly stored entries.
pub fn col_indices(&self) -> &[usize] {
&self.col_indices
}
/// The values of the explicitly stored entries.
pub fn values(&self) -> &[T] {
&self.values
}
/// Disassembles the matrix into individual triplet arrays.
///
/// Examples
/// --------
///
/// ```
/// # use nalgebra_sparse::coo::CooMatrix;
/// let row_indices = vec![0, 1];
/// let col_indices = vec![1, 2];
/// let values = vec![1.0, 2.0];
/// let coo = CooMatrix::try_from_triplets(2, 3, row_indices, col_indices, values)
/// .unwrap();
///
/// let (row_idx, col_idx, val) = coo.disassemble();
/// assert_eq!(row_idx, vec![0, 1]);
/// assert_eq!(col_idx, vec![1, 2]);
/// assert_eq!(val, vec![1.0, 2.0]);
/// ```
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
(self.row_indices, self.col_indices, self.values)
}
}

530
nalgebra-sparse/src/cs.rs Normal file
View File

@ -0,0 +1,530 @@
use std::mem::replace;
use std::ops::Range;
use num_traits::One;
use nalgebra::Scalar;
use crate::pattern::SparsityPattern;
use crate::{SparseEntry, SparseEntryMut};
/// An abstract compressed matrix.
///
/// For the time being, this is only used internally to share implementation between
/// CSR and CSC matrices.
///
/// A CSR matrix is obtained by associating rows with the major dimension, while a CSC matrix
/// is obtained by associating columns with the major dimension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsMatrix<T> {
sparsity_pattern: SparsityPattern,
values: Vec<T>,
}
impl<T> CsMatrix<T> {
/// Create a zero matrix with no explicitly stored entries.
#[inline]
pub fn new(major_dim: usize, minor_dim: usize) -> Self {
Self {
sparsity_pattern: SparsityPattern::zeros(major_dim, minor_dim),
values: vec![],
}
}
#[inline]
pub fn pattern(&self) -> &SparsityPattern {
&self.sparsity_pattern
}
#[inline]
pub fn values(&self) -> &[T] {
&self.values
}
#[inline]
pub fn values_mut(&mut self) -> &mut [T] {
&mut self.values
}
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
#[inline]
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
let pattern = self.pattern();
(
pattern.major_offsets(),
pattern.minor_indices(),
&self.values,
)
}
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
#[inline]
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
let pattern = &mut self.sparsity_pattern;
(
pattern.major_offsets(),
pattern.minor_indices(),
&mut self.values,
)
}
#[inline]
pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
(&self.sparsity_pattern, &mut self.values)
}
#[inline]
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>) -> Self {
assert_eq!(
pattern.nnz(),
values.len(),
"Internal error: consumers should verify shape compatibility."
);
Self {
sparsity_pattern: pattern,
values,
}
}
/// Internal method for simplifying access to a lane's data
#[inline]
pub fn get_index_range(&self, row_index: usize) -> Option<Range<usize>> {
let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?;
let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?;
Some(row_begin..row_end)
}
pub fn take_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
(self.sparsity_pattern, self.values)
}
#[inline]
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
let (offsets, indices) = self.sparsity_pattern.disassemble();
(offsets, indices, self.values)
}
#[inline]
pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
(self.sparsity_pattern, self.values)
}
/// Returns an entry for the given major/minor indices, or `None` if the indices are out
/// of bounds.
pub fn get_entry(&self, major_index: usize, minor_index: usize) -> Option<SparseEntry<T>> {
let row_range = self.get_index_range(major_index)?;
let (_, minor_indices, values) = self.cs_data();
let minor_indices = &minor_indices[row_range.clone()];
let values = &values[row_range];
get_entry_from_slices(
self.pattern().minor_dim(),
minor_indices,
values,
minor_index,
)
}
/// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out
/// of bounds.
pub fn get_entry_mut(
&mut self,
major_index: usize,
minor_index: usize,
) -> Option<SparseEntryMut<T>> {
let row_range = self.get_index_range(major_index)?;
let minor_dim = self.pattern().minor_dim();
let (_, minor_indices, values) = self.cs_data_mut();
let minor_indices = &minor_indices[row_range.clone()];
let values = &mut values[row_range];
get_mut_entry_from_slices(minor_dim, minor_indices, values, minor_index)
}
pub fn get_lane(&self, index: usize) -> Option<CsLane<T>> {
let range = self.get_index_range(index)?;
let (_, minor_indices, values) = self.cs_data();
Some(CsLane {
minor_indices: &minor_indices[range.clone()],
values: &values[range],
minor_dim: self.pattern().minor_dim(),
})
}
#[inline]
pub fn get_lane_mut(&mut self, index: usize) -> Option<CsLaneMut<T>> {
let range = self.get_index_range(index)?;
let minor_dim = self.pattern().minor_dim();
let (_, minor_indices, values) = self.cs_data_mut();
Some(CsLaneMut {
minor_dim,
minor_indices: &minor_indices[range.clone()],
values: &mut values[range],
})
}
#[inline]
pub fn lane_iter(&self) -> CsLaneIter<T> {
CsLaneIter::new(self.pattern(), self.values())
}
#[inline]
pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> {
CsLaneIterMut::new(&self.sparsity_pattern, &mut self.values)
}
#[inline]
pub fn filter<P>(&self, predicate: P) -> Self
where
T: Clone,
P: Fn(usize, usize, &T) -> bool,
{
let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim());
let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1);
let mut new_indices = Vec::new();
let mut new_values = Vec::new();
new_offsets.push(0);
for (i, lane) in self.lane_iter().enumerate() {
for (&j, value) in lane.minor_indices().iter().zip(lane.values) {
if predicate(i, j, value) {
new_indices.push(j);
new_values.push(value.clone());
}
}
new_offsets.push(new_indices.len());
}
// TODO: Avoid checks here
let new_pattern = SparsityPattern::try_from_offsets_and_indices(
major_dim,
minor_dim,
new_offsets,
new_indices,
)
.expect("Internal error: Sparsity pattern must always be valid.");
Self::from_pattern_and_values(new_pattern, new_values)
}
/// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_matrix(&self) -> Self
where
T: Clone,
{
// TODO: This might be faster with a binary search for each diagonal entry
self.filter(|i, j, _| i == j)
}
}
impl<T: Scalar + One> CsMatrix<T> {
#[inline]
pub fn identity(n: usize) -> Self {
let offsets: Vec<_> = (0..=n).collect();
let indices: Vec<_> = (0..n).collect();
let values = vec![T::one(); n];
// TODO: We should skip checks here
let pattern =
SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices).unwrap();
Self::from_pattern_and_values(pattern, values)
}
}
fn get_entry_from_slices<'a, T>(
minor_dim: usize,
minor_indices: &'a [usize],
values: &'a [T],
global_minor_index: usize,
) -> Option<SparseEntry<'a, T>> {
let local_index = minor_indices.binary_search(&global_minor_index);
if let Ok(local_index) = local_index {
Some(SparseEntry::NonZero(&values[local_index]))
} else if global_minor_index < minor_dim {
Some(SparseEntry::Zero)
} else {
None
}
}
fn get_mut_entry_from_slices<'a, T>(
minor_dim: usize,
minor_indices: &'a [usize],
values: &'a mut [T],
global_minor_indices: usize,
) -> Option<SparseEntryMut<'a, T>> {
let local_index = minor_indices.binary_search(&global_minor_indices);
if let Ok(local_index) = local_index {
Some(SparseEntryMut::NonZero(&mut values[local_index]))
} else if global_minor_indices < minor_dim {
Some(SparseEntryMut::Zero)
} else {
None
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsLane<'a, T> {
minor_dim: usize,
minor_indices: &'a [usize],
values: &'a [T],
}
#[derive(Debug, PartialEq, Eq)]
pub struct CsLaneMut<'a, T> {
minor_dim: usize,
minor_indices: &'a [usize],
values: &'a mut [T],
}
pub struct CsLaneIter<'a, T> {
// The index of the lane that will be returned on the next iteration
current_lane_idx: usize,
pattern: &'a SparsityPattern,
remaining_values: &'a [T],
}
impl<'a, T> CsLaneIter<'a, T> {
pub fn new(pattern: &'a SparsityPattern, values: &'a [T]) -> Self {
Self {
current_lane_idx: 0,
pattern,
remaining_values: values,
}
}
}
impl<'a, T> Iterator for CsLaneIter<'a, T>
where
T: 'a,
{
type Item = CsLane<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
let lane = self.pattern.get_lane(self.current_lane_idx);
let minor_dim = self.pattern.minor_dim();
if let Some(minor_indices) = lane {
let count = minor_indices.len();
let values_in_lane = &self.remaining_values[..count];
self.remaining_values = &self.remaining_values[count..];
self.current_lane_idx += 1;
Some(CsLane {
minor_dim,
minor_indices,
values: values_in_lane,
})
} else {
None
}
}
}
pub struct CsLaneIterMut<'a, T> {
// The index of the lane that will be returned on the next iteration
current_lane_idx: usize,
pattern: &'a SparsityPattern,
remaining_values: &'a mut [T],
}
impl<'a, T> CsLaneIterMut<'a, T> {
pub fn new(pattern: &'a SparsityPattern, values: &'a mut [T]) -> Self {
Self {
current_lane_idx: 0,
pattern,
remaining_values: values,
}
}
}
impl<'a, T> Iterator for CsLaneIterMut<'a, T>
where
T: 'a,
{
type Item = CsLaneMut<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
let lane = self.pattern.get_lane(self.current_lane_idx);
let minor_dim = self.pattern.minor_dim();
if let Some(minor_indices) = lane {
let count = minor_indices.len();
let remaining = replace(&mut self.remaining_values, &mut []);
let (values_in_lane, remaining) = remaining.split_at_mut(count);
self.remaining_values = remaining;
self.current_lane_idx += 1;
Some(CsLaneMut {
minor_dim,
minor_indices,
values: values_in_lane,
})
} else {
None
}
}
}
/// Implement the methods common to both CsLane and CsLaneMut. See the documentation for the
/// methods delegated here by CsrMatrix and CscMatrix members for more information.
macro_rules! impl_cs_lane_common_methods {
($name:ty) => {
impl<'a, T> $name {
#[inline]
pub fn minor_dim(&self) -> usize {
self.minor_dim
}
#[inline]
pub fn nnz(&self) -> usize {
self.minor_indices.len()
}
#[inline]
pub fn minor_indices(&self) -> &[usize] {
self.minor_indices
}
#[inline]
pub fn values(&self) -> &[T] {
self.values
}
#[inline]
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
get_entry_from_slices(
self.minor_dim,
self.minor_indices,
self.values,
global_col_index,
)
}
}
};
}
impl_cs_lane_common_methods!(CsLane<'a, T>);
impl_cs_lane_common_methods!(CsLaneMut<'a, T>);
impl<'a, T> CsLaneMut<'a, T> {
pub fn values_mut(&mut self) -> &mut [T] {
self.values
}
pub fn indices_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
(self.minor_indices, self.values)
}
pub fn get_entry_mut(&mut self, global_minor_index: usize) -> Option<SparseEntryMut<T>> {
get_mut_entry_from_slices(
self.minor_dim,
self.minor_indices,
self.values,
global_minor_index,
)
}
}
/// Helper struct for working with uninitialized data in vectors.
/// TODO: This doesn't belong here.
struct UninitVec<T> {
vec: Vec<T>,
len: usize,
}
impl<T> UninitVec<T> {
pub fn from_len(len: usize) -> Self {
Self {
vec: Vec::with_capacity(len),
// We need to store len separately, because for zero-sized types,
// Vec::with_capacity(len) does not give vec.capacity() == len
len,
}
}
/// Sets the element associated with the given index to the provided value.
///
/// Must be called exactly once per index, otherwise results in undefined behavior.
pub unsafe fn set(&mut self, index: usize, value: T) {
self.vec.as_mut_ptr().add(index).write(value)
}
/// Marks the vector data as initialized by returning a full vector.
///
/// It is undefined behavior to call this function unless *all* elements have been written to
/// exactly once.
pub unsafe fn assume_init(mut self) -> Vec<T> {
self.vec.set_len(self.len);
self.vec
}
}
/// Transposes the compressed format.
///
/// This means that major and minor roles are switched. This is used for converting between CSR
/// and CSC formats.
pub fn transpose_cs<T>(
major_dim: usize,
minor_dim: usize,
source_major_offsets: &[usize],
source_minor_indices: &[usize],
values: &[T],
) -> (Vec<usize>, Vec<usize>, Vec<T>)
where
T: Scalar,
{
assert_eq!(source_major_offsets.len(), major_dim + 1);
assert_eq!(source_minor_indices.len(), values.len());
let nnz = values.len();
// Count the number of occurences of each minor index
let mut minor_counts = vec![0; minor_dim];
for minor_idx in source_minor_indices {
minor_counts[*minor_idx] += 1;
}
convert_counts_to_offsets(&mut minor_counts);
let mut target_offsets = minor_counts;
target_offsets.push(nnz);
let mut target_indices = vec![usize::MAX; nnz];
// We have to use uninitialized storage, because we don't have any kind of "default" value
// available for `T`. Unfortunately this necessitates some small amount of unsafe code
let mut target_values = UninitVec::from_len(nnz);
// Keep track of how many entries we have placed in each target major lane
let mut current_target_major_counts = vec![0; minor_dim];
for source_major_idx in 0..major_dim {
let source_lane_begin = source_major_offsets[source_major_idx];
let source_lane_end = source_major_offsets[source_major_idx + 1];
let source_lane_indices = &source_minor_indices[source_lane_begin..source_lane_end];
let source_lane_values = &values[source_lane_begin..source_lane_end];
for (&source_minor_idx, val) in source_lane_indices.iter().zip(source_lane_values) {
// Compute the offset in the target data for this particular source entry
let target_lane_count = &mut current_target_major_counts[source_minor_idx];
let entry_offset = target_offsets[source_minor_idx] + *target_lane_count;
target_indices[entry_offset] = source_major_idx;
unsafe {
target_values.set(entry_offset, val.inlined_clone());
}
*target_lane_count += 1;
}
}
// At this point, we should have written to each element in target_values exactly once,
// so initialization should be sound
let target_values = unsafe { target_values.assume_init() };
(target_offsets, target_indices, target_values)
}
pub fn convert_counts_to_offsets(counts: &mut [usize]) {
// Convert the counts to an offset
let mut offset = 0;
for i_offset in counts.iter_mut() {
let count = *i_offset;
*i_offset = offset;
offset += count;
}
}

704
nalgebra-sparse/src/csc.rs Normal file
View File

@ -0,0 +1,704 @@
//! An implementation of the CSC sparse matrix format.
//!
//! This is the module-level documentation. See [`CscMatrix`] for the main documentation of the
//! CSC implementation.
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::csr::CsrMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use nalgebra::Scalar;
use num_traits::One;
use std::slice::{Iter, IterMut};
/// A CSC representation of a sparse matrix.
///
/// The Compressed Sparse Column (CSC) format is well-suited as a general-purpose storage format
/// for many sparse matrix applications.
///
/// # Usage
///
/// ```rust
/// use nalgebra_sparse::csc::CscMatrix;
/// use nalgebra::{DMatrix, Matrix3x4};
/// use matrixcompare::assert_matrix_eq;
///
/// // The sparsity patterns of CSC matrices are immutable. This means that you cannot dynamically
/// // change the sparsity pattern of the matrix after it has been constructed. The easiest
/// // way to construct a CSC matrix is to first incrementally construct a COO matrix,
/// // and then convert it to CSC.
/// # use nalgebra_sparse::coo::CooMatrix;
/// # let coo = CooMatrix::<f64>::new(3, 3);
/// let csc = CscMatrix::from(&coo);
///
/// // Alternatively, a CSC matrix can be constructed directly from raw CSC data.
/// // Here, we construct a 3x4 matrix
/// let col_offsets = vec![0, 1, 3, 4, 5];
/// let row_indices = vec![0, 0, 2, 2, 0];
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
///
/// // The dense representation of the CSC data, for comparison
/// let dense = Matrix3x4::new(1.0, 2.0, 0.0, 5.0,
/// 0.0, 0.0, 0.0, 0.0,
/// 0.0, 3.0, 4.0, 0.0);
///
/// // The constructor validates the raw CSC data and returns an error if it is invalid.
/// let csc = CscMatrix::try_from_csc_data(3, 4, col_offsets, row_indices, values)
/// .expect("CSC data must conform to format specifications");
/// assert_matrix_eq!(csc, dense);
///
/// // A third approach is to construct a CSC matrix from a pattern and values. Sometimes this is
/// // useful if the sparsity pattern is constructed separately from the values of the matrix.
/// let (pattern, values) = csc.into_pattern_and_values();
/// let csc = CscMatrix::try_from_pattern_and_values(pattern, values)
/// .expect("The pattern and values must be compatible");
///
/// // Once we have constructed our matrix, we can use it for arithmetic operations together with
/// // other CSC matrices and dense matrices/vectors.
/// let x = csc;
/// # #[allow(non_snake_case)]
/// let xTx = x.transpose() * &x;
/// let z = DMatrix::from_fn(4, 8, |i, j| (i as f64) * (j as f64));
/// let w = 3.0 * xTx * z;
///
/// // Although the sparsity pattern of a CSC matrix cannot be changed, its values can.
/// // Here are two different ways to scale all values by a constant:
/// let mut x = x;
/// x *= 5.0;
/// x.values_mut().iter_mut().for_each(|x_i| *x_i *= 5.0);
/// ```
///
/// # Format
///
/// An `m x n` sparse matrix with `nnz` non-zeros in CSC format is represented by the
/// following three arrays:
///
/// - `col_offsets`, an array of integers with length `n + 1`.
/// - `row_indices`, an array of integers with length `nnz`.
/// - `values`, an array of values with length `nnz`.
///
/// The relationship between the arrays is described below.
///
/// - Each consecutive pair of entries `col_offsets[j] .. col_offsets[j + 1]` corresponds to an
/// offset range in `row_indices` that holds the row indices in column `j`.
/// - For an entry represented by the index `idx`, `row_indices[idx]` stores its column index and
/// `values[idx]` stores its value.
///
/// The following invariants must be upheld and are enforced by the data structure:
///
/// - `col_offsets[0] == 0`
/// - `col_offsets[m] == nnz`
/// - `col_offsets` is monotonically increasing.
/// - `0 <= row_indices[idx] < m` for all `idx < nnz`.
/// - The row indices associated with each column are monotonically increasing (see below).
///
/// The CSC format is a standard sparse matrix format (see [Wikipedia article]). The format
/// represents the matrix in a column-by-column fashion. The entries associated with column `j` are
/// determined as follows:
///
/// ```rust
/// # let col_offsets: Vec<usize> = vec![0, 0];
/// # let row_indices: Vec<usize> = vec![];
/// # let values: Vec<i32> = vec![];
/// # let j = 0;
/// let range = col_offsets[j] .. col_offsets[j + 1];
/// let col_j_rows = &row_indices[range.clone()];
/// let col_j_vals = &values[range];
///
/// // For each pair (i, v) in (col_j_rows, col_j_vals), we obtain a corresponding entry
/// // (i, j, v) in the matrix.
/// assert_eq!(col_j_rows.len(), col_j_vals.len());
/// ```
///
/// In the above example, for each column `j`, the row indices `col_j_cols` must appear in
/// monotonically increasing order. In other words, they must be *sorted*. This criterion is not
/// standard among all sparse matrix libraries, but we enforce this property as it is a crucial
/// assumption for both correctness and performance for many algorithms.
///
/// Note that the CSR and CSC formats are essentially identical, except that CSC stores the matrix
/// column-by-column instead of row-by-row like CSR.
///
/// [Wikipedia article]: https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_(CSC_or_CCS)
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CscMatrix<T> {
// Cols are major, rows are minor in the sparsity pattern
pub(crate) cs: CsMatrix<T>,
}
impl<T> CscMatrix<T> {
/// Constructs a CSC representation of the (square) `n x n` identity matrix.
#[inline]
pub fn identity(n: usize) -> Self
where
T: Scalar + One,
{
Self {
cs: CsMatrix::identity(n),
}
}
/// Create a zero CSC matrix with no explicitly stored entries.
pub fn zeros(nrows: usize, ncols: usize) -> Self {
Self {
cs: CsMatrix::new(ncols, nrows),
}
}
/// Try to construct a CSC matrix from raw CSC data.
///
/// It is assumed that each column contains unique and sorted row indices that are in
/// bounds with respect to the number of rows in the matrix. If this is not the case,
/// an error is returned to indicate the failure.
///
/// An error is returned if the data given does not conform to the CSC storage format.
/// See the documentation for [CscMatrix](struct.CscMatrix.html) for more information.
pub fn try_from_csc_data(
num_rows: usize,
num_cols: usize,
col_offsets: Vec<usize>,
row_indices: Vec<usize>,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
let pattern = SparsityPattern::try_from_offsets_and_indices(
num_cols,
num_rows,
col_offsets,
row_indices,
)
.map_err(pattern_format_error_to_csc_error)?;
Self::try_from_pattern_and_values(pattern, values)
}
/// Try to construct a CSC matrix from a sparsity pattern and associated non-zero values.
///
/// Returns an error if the number of values does not match the number of minor indices
/// in the pattern.
pub fn try_from_pattern_and_values(
pattern: SparsityPattern,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() {
Ok(Self {
cs: CsMatrix::from_pattern_and_values(pattern, values),
})
} else {
Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of values and row indices must be the same",
))
}
}
/// The number of rows in the matrix.
#[inline]
pub fn nrows(&self) -> usize {
self.cs.pattern().minor_dim()
}
/// The number of columns in the matrix.
#[inline]
pub fn ncols(&self) -> usize {
self.cs.pattern().major_dim()
}
/// The number of non-zeros in the matrix.
///
/// Note that this corresponds to the number of explicitly stored entries, *not* the actual
/// number of algebraically zero entries in the matrix. Explicitly stored entries can still
/// be zero. Corresponds to the number of entries in the sparsity pattern.
#[inline]
pub fn nnz(&self) -> usize {
self.pattern().nnz()
}
/// The column offsets defining part of the CSC format.
#[inline]
pub fn col_offsets(&self) -> &[usize] {
self.pattern().major_offsets()
}
/// The row indices defining part of the CSC format.
#[inline]
pub fn row_indices(&self) -> &[usize] {
self.pattern().minor_indices()
}
/// The non-zero values defining part of the CSC format.
#[inline]
pub fn values(&self) -> &[T] {
self.cs.values()
}
/// Mutable access to the non-zero values.
#[inline]
pub fn values_mut(&mut self) -> &mut [T] {
self.cs.values_mut()
}
/// An iterator over non-zero triplets (i, j, v).
///
/// The iteration happens in column-major fashion, meaning that j increases monotonically,
/// and i increases monotonically within each row.
///
/// Examples
/// --------
/// ```
/// # use nalgebra_sparse::csc::CscMatrix;
/// let col_offsets = vec![0, 2, 3, 4];
/// let row_indices = vec![0, 2, 1, 0];
/// let values = vec![1, 3, 2, 4];
/// let mut csc = CscMatrix::try_from_csc_data(4, 3, col_offsets, row_indices, values)
/// .unwrap();
///
/// let triplets: Vec<_> = csc.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
/// assert_eq!(triplets, vec![(0, 0, 1), (2, 0, 3), (1, 1, 2), (0, 2, 4)]);
/// ```
pub fn triplet_iter(&self) -> CscTripletIter<T> {
CscTripletIter {
pattern_iter: self.pattern().entries(),
values_iter: self.values().iter(),
}
}
/// A mutable iterator over non-zero triplets (i, j, v).
///
/// Iteration happens in the same order as for [triplet_iter](#method.triplet_iter).
///
/// Examples
/// --------
/// ```
/// # use nalgebra_sparse::csc::CscMatrix;
/// let col_offsets = vec![0, 2, 3, 4];
/// let row_indices = vec![0, 2, 1, 0];
/// let values = vec![1, 3, 2, 4];
/// // Using the same data as in the `triplet_iter` example
/// let mut csc = CscMatrix::try_from_csc_data(4, 3, col_offsets, row_indices, values)
/// .unwrap();
///
/// // Zero out lower-triangular terms
/// csc.triplet_iter_mut()
/// .filter(|(i, j, _)| j < i)
/// .for_each(|(_, _, v)| *v = 0);
///
/// let triplets: Vec<_> = csc.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
/// assert_eq!(triplets, vec![(0, 0, 1), (2, 0, 0), (1, 1, 2), (0, 2, 4)]);
/// ```
pub fn triplet_iter_mut(&mut self) -> CscTripletIterMut<T> {
let (pattern, values) = self.cs.pattern_and_values_mut();
CscTripletIterMut {
pattern_iter: pattern.entries(),
values_mut_iter: values.iter_mut(),
}
}
/// Return the column at the given column index.
///
/// Panics
/// ------
/// Panics if column index is out of bounds.
#[inline]
pub fn col(&self, index: usize) -> CscCol<T> {
self.get_col(index).expect("Row index must be in bounds")
}
/// Mutable column access for the given column index.
///
/// Panics
/// ------
/// Panics if column index is out of bounds.
#[inline]
pub fn col_mut(&mut self, index: usize) -> CscColMut<T> {
self.get_col_mut(index)
.expect("Row index must be in bounds")
}
/// Return the column at the given column index, or `None` if out of bounds.
#[inline]
pub fn get_col(&self, index: usize) -> Option<CscCol<T>> {
self.cs.get_lane(index).map(|lane| CscCol { lane })
}
/// Mutable column access for the given column index, or `None` if out of bounds.
#[inline]
pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> {
self.cs.get_lane_mut(index).map(|lane| CscColMut { lane })
}
/// An iterator over columns in the matrix.
pub fn col_iter(&self) -> CscColIter<T> {
CscColIter {
lane_iter: CsLaneIter::new(self.pattern(), self.values()),
}
}
/// A mutable iterator over columns in the matrix.
pub fn col_iter_mut(&mut self) -> CscColIterMut<T> {
let (pattern, values) = self.cs.pattern_and_values_mut();
CscColIterMut {
lane_iter: CsLaneIterMut::new(pattern, values),
}
}
/// Disassembles the CSC matrix into its underlying offset, index and value arrays.
///
/// If the matrix contains the sole reference to the sparsity pattern,
/// then the data is returned as-is. Otherwise, the sparsity pattern is cloned.
///
/// Examples
/// --------
///
/// ```
/// # use nalgebra_sparse::csc::CscMatrix;
/// let col_offsets = vec![0, 2, 3, 4];
/// let row_indices = vec![0, 2, 1, 0];
/// let values = vec![1, 3, 2, 4];
/// let mut csc = CscMatrix::try_from_csc_data(
/// 4,
/// 3,
/// col_offsets.clone(),
/// row_indices.clone(),
/// values.clone())
/// .unwrap();
/// let (col_offsets2, row_indices2, values2) = csc.disassemble();
/// assert_eq!(col_offsets2, col_offsets);
/// assert_eq!(row_indices2, row_indices);
/// assert_eq!(values2, values);
/// ```
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
self.cs.disassemble()
}
/// Returns the sparsity pattern and values associated with this matrix.
pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
self.cs.into_pattern_and_values()
}
/// Returns a reference to the sparsity pattern and a mutable reference to the values.
#[inline]
pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
self.cs.pattern_and_values_mut()
}
/// Returns a reference to the underlying sparsity pattern.
pub fn pattern(&self) -> &SparsityPattern {
self.cs.pattern()
}
/// Reinterprets the CSC matrix as its transpose represented by a CSR matrix.
///
/// This operation does not touch the CSC data, and is effectively a no-op.
pub fn transpose_as_csr(self) -> CsrMatrix<T> {
let (pattern, values) = self.cs.take_pattern_and_values();
CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap()
}
/// Returns an entry for the given row/col indices, or `None` if the indices are out of bounds.
///
/// Each call to this function incurs the cost of a binary search among the explicitly
/// stored row entries for the given column.
pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> {
self.cs.get_entry(col_index, row_index)
}
/// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
/// of bounds.
///
/// Each call to this function incurs the cost of a binary search among the explicitly
/// stored row entries for the given column.
pub fn get_entry_mut(
&mut self,
row_index: usize,
col_index: usize,
) -> Option<SparseEntryMut<T>> {
self.cs.get_entry_mut(col_index, row_index)
}
/// Returns an entry for the given row/col indices.
///
/// Same as `get_entry`, except that it directly panics upon encountering row/col indices
/// out of bounds.
///
/// Panics
/// ------
/// Panics if `row_index` or `col_index` is out of bounds.
pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<T> {
self.get_entry(row_index, col_index)
.expect("Out of bounds matrix indices encountered")
}
/// Returns a mutable entry for the given row/col indices.
///
/// Same as `get_entry_mut`, except that it directly panics upon encountering row/col indices
/// out of bounds.
///
/// Panics
/// ------
/// Panics if `row_index` or `col_index` is out of bounds.
pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<T> {
self.get_entry_mut(row_index, col_index)
.expect("Out of bounds matrix indices encountered")
}
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSC data.
pub fn csc_data(&self) -> (&[usize], &[usize], &[T]) {
self.cs.cs_data()
}
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSC data,
/// where the `values` array is mutable.
pub fn csc_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
self.cs.cs_data_mut()
}
/// Creates a sparse matrix that contains only the explicit entries decided by the
/// given predicate.
pub fn filter<P>(&self, predicate: P) -> Self
where
T: Clone,
P: Fn(usize, usize, &T) -> bool,
{
// Note: Predicate uses (row, col, value), so we have to switch around since
// cs uses (major, minor, value)
Self {
cs: self
.cs
.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)),
}
}
/// Returns a new matrix representing the upper triangular part of this matrix.
///
/// The result includes the diagonal of the matrix.
pub fn upper_triangle(&self) -> Self
where
T: Clone,
{
self.filter(|i, j, _| i <= j)
}
/// Returns a new matrix representing the lower triangular part of this matrix.
///
/// The result includes the diagonal of the matrix.
pub fn lower_triangle(&self) -> Self
where
T: Clone,
{
self.filter(|i, j, _| i >= j)
}
/// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_csc(&self) -> Self
where
T: Clone,
{
Self {
cs: self.cs.diagonal_as_matrix(),
}
}
/// Compute the transpose of the matrix.
pub fn transpose(&self) -> CscMatrix<T>
where
T: Scalar,
{
CsrMatrix::from(self).transpose_as_csc()
}
}
/// Convert pattern format errors into more meaningful CSC-specific errors.
///
/// This ensures that the terminology is consistent: we are talking about rows and columns,
/// not lanes, major and minor dimensions.
fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseFormatError {
use SparseFormatError as E;
use SparseFormatErrorKind as K;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*;
match err {
InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure,
"Length of col offset array is not equal to ncols + 1.",
),
InvalidOffsetFirstLast => E::from_kind_and_msg(
K::InvalidStructure,
"First or last col offset is inconsistent with format specification.",
),
NonmonotonicOffsets => E::from_kind_and_msg(
K::InvalidStructure,
"Col offsets are not monotonically increasing.",
),
NonmonotonicMinorIndices => E::from_kind_and_msg(
K::InvalidStructure,
"Row indices are not monotonically increasing (sorted) within each column.",
),
MinorIndexOutOfBounds => {
E::from_kind_and_msg(K::IndexOutOfBounds, "Row indices are out of bounds.")
}
PatternDuplicateEntry => {
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
}
}
}
/// Iterator type for iterating over triplets in a CSC matrix.
#[derive(Debug)]
pub struct CscTripletIter<'a, T> {
pattern_iter: SparsityPatternIter<'a>,
values_iter: Iter<'a, T>,
}
impl<'a, T: Clone> CscTripletIter<'a, T> {
/// Adapts the triplet iterator to return owned values.
///
/// The triplet iterator returns references to the values. This method adapts the iterator
/// so that the values are cloned.
#[inline]
pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
self.map(|(i, j, v)| (i, j, v.clone()))
}
}
impl<'a, T> Iterator for CscTripletIter<'a, T> {
type Item = (usize, usize, &'a T);
fn next(&mut self) -> Option<Self::Item> {
let next_entry = self.pattern_iter.next();
let next_value = self.values_iter.next();
match (next_entry, next_value) {
(Some((i, j)), Some(v)) => Some((j, i, v)),
_ => None,
}
}
}
/// Iterator type for mutably iterating over triplets in a CSC matrix.
#[derive(Debug)]
pub struct CscTripletIterMut<'a, T> {
pattern_iter: SparsityPatternIter<'a>,
values_mut_iter: IterMut<'a, T>,
}
impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
type Item = (usize, usize, &'a mut T);
#[inline]
fn next(&mut self) -> Option<Self::Item> {
let next_entry = self.pattern_iter.next();
let next_value = self.values_mut_iter.next();
match (next_entry, next_value) {
(Some((i, j)), Some(v)) => Some((j, i, v)),
_ => None,
}
}
}
/// An immutable representation of a column in a CSC matrix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CscCol<'a, T> {
lane: CsLane<'a, T>,
}
/// A mutable representation of a column in a CSC matrix.
///
/// Note that only explicitly stored entries can be mutated. The sparsity pattern belonging
/// to the column cannot be modified.
#[derive(Debug, PartialEq, Eq)]
pub struct CscColMut<'a, T> {
lane: CsLaneMut<'a, T>,
}
/// Implement the methods common to both CscCol and CscColMut
macro_rules! impl_csc_col_common_methods {
($name:ty) => {
impl<'a, T> $name {
/// The number of global rows in the column.
#[inline]
pub fn nrows(&self) -> usize {
self.lane.minor_dim()
}
/// The number of non-zeros in this column.
#[inline]
pub fn nnz(&self) -> usize {
self.lane.nnz()
}
/// The row indices corresponding to explicitly stored entries in this column.
#[inline]
pub fn row_indices(&self) -> &[usize] {
self.lane.minor_indices()
}
/// The values corresponding to explicitly stored entries in this column.
#[inline]
pub fn values(&self) -> &[T] {
self.lane.values()
}
/// Returns an entry for the given global row index.
///
/// Each call to this function incurs the cost of a binary search among the explicitly
/// stored row entries.
pub fn get_entry(&self, global_row_index: usize) -> Option<SparseEntry<T>> {
self.lane.get_entry(global_row_index)
}
}
};
}
impl_csc_col_common_methods!(CscCol<'a, T>);
impl_csc_col_common_methods!(CscColMut<'a, T>);
impl<'a, T> CscColMut<'a, T> {
/// Mutable access to the values corresponding to explicitly stored entries in this column.
pub fn values_mut(&mut self) -> &mut [T] {
self.lane.values_mut()
}
/// Provides simultaneous access to row indices and mutable values corresponding to the
/// explicitly stored entries in this column.
///
/// This method primarily facilitates low-level access for methods that process data stored
/// in CSC format directly.
pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
self.lane.indices_and_values_mut()
}
/// Returns a mutable entry for the given global row index.
pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option<SparseEntryMut<T>> {
self.lane.get_entry_mut(global_row_index)
}
}
/// Column iterator for [CscMatrix](struct.CscMatrix.html).
pub struct CscColIter<'a, T> {
lane_iter: CsLaneIter<'a, T>,
}
impl<'a, T> Iterator for CscColIter<'a, T> {
type Item = CscCol<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
self.lane_iter.next().map(|lane| CscCol { lane })
}
}
/// Mutable column iterator for [CscMatrix](struct.CscMatrix.html).
pub struct CscColIterMut<'a, T> {
lane_iter: CsLaneIterMut<'a, T>,
}
impl<'a, T> Iterator for CscColIterMut<'a, T>
where
T: 'a,
{
type Item = CscColMut<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
self.lane_iter.next().map(|lane| CscColMut { lane })
}
}

708
nalgebra-sparse/src/csr.rs Normal file
View File

@ -0,0 +1,708 @@
//! An implementation of the CSR sparse matrix format.
//!
//! This is the module-level documentation. See [`CsrMatrix`] for the main documentation of the
//! CSC implementation.
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
use crate::csc::CscMatrix;
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
use nalgebra::Scalar;
use num_traits::One;
use std::slice::{Iter, IterMut};
/// A CSR representation of a sparse matrix.
///
/// The Compressed Sparse Row (CSR) format is well-suited as a general-purpose storage format
/// for many sparse matrix applications.
///
/// # Usage
///
/// ```rust
/// use nalgebra_sparse::csr::CsrMatrix;
/// use nalgebra::{DMatrix, Matrix3x4};
/// use matrixcompare::assert_matrix_eq;
///
/// // The sparsity patterns of CSR matrices are immutable. This means that you cannot dynamically
/// // change the sparsity pattern of the matrix after it has been constructed. The easiest
/// // way to construct a CSR matrix is to first incrementally construct a COO matrix,
/// // and then convert it to CSR.
/// # use nalgebra_sparse::coo::CooMatrix;
/// # let coo = CooMatrix::<f64>::new(3, 3);
/// let csr = CsrMatrix::from(&coo);
///
/// // Alternatively, a CSR matrix can be constructed directly from raw CSR data.
/// // Here, we construct a 3x4 matrix
/// let row_offsets = vec![0, 3, 3, 5];
/// let col_indices = vec![0, 1, 3, 1, 2];
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
///
/// // The dense representation of the CSR data, for comparison
/// let dense = Matrix3x4::new(1.0, 2.0, 0.0, 3.0,
/// 0.0, 0.0, 0.0, 0.0,
/// 0.0, 4.0, 5.0, 0.0);
///
/// // The constructor validates the raw CSR data and returns an error if it is invalid.
/// let csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
/// .expect("CSR data must conform to format specifications");
/// assert_matrix_eq!(csr, dense);
///
/// // A third approach is to construct a CSR matrix from a pattern and values. Sometimes this is
/// // useful if the sparsity pattern is constructed separately from the values of the matrix.
/// let (pattern, values) = csr.into_pattern_and_values();
/// let csr = CsrMatrix::try_from_pattern_and_values(pattern, values)
/// .expect("The pattern and values must be compatible");
///
/// // Once we have constructed our matrix, we can use it for arithmetic operations together with
/// // other CSR matrices and dense matrices/vectors.
/// let x = csr;
/// # #[allow(non_snake_case)]
/// let xTx = x.transpose() * &x;
/// let z = DMatrix::from_fn(4, 8, |i, j| (i as f64) * (j as f64));
/// let w = 3.0 * xTx * z;
///
/// // Although the sparsity pattern of a CSR matrix cannot be changed, its values can.
/// // Here are two different ways to scale all values by a constant:
/// let mut x = x;
/// x *= 5.0;
/// x.values_mut().iter_mut().for_each(|x_i| *x_i *= 5.0);
/// ```
///
/// # Format
///
/// An `m x n` sparse matrix with `nnz` non-zeros in CSR format is represented by the
/// following three arrays:
///
/// - `row_offsets`, an array of integers with length `m + 1`.
/// - `col_indices`, an array of integers with length `nnz`.
/// - `values`, an array of values with length `nnz`.
///
/// The relationship between the arrays is described below.
///
/// - Each consecutive pair of entries `row_offsets[i] .. row_offsets[i + 1]` corresponds to an
/// offset range in `col_indices` that holds the column indices in row `i`.
/// - For an entry represented by the index `idx`, `col_indices[idx]` stores its column index and
/// `values[idx]` stores its value.
///
/// The following invariants must be upheld and are enforced by the data structure:
///
/// - `row_offsets[0] == 0`
/// - `row_offsets[m] == nnz`
/// - `row_offsets` is monotonically increasing.
/// - `0 <= col_indices[idx] < n` for all `idx < nnz`.
/// - The column indices associated with each row are monotonically increasing (see below).
///
/// The CSR format is a standard sparse matrix format (see [Wikipedia article]). The format
/// represents the matrix in a row-by-row fashion. The entries associated with row `i` are
/// determined as follows:
///
/// ```rust
/// # let row_offsets: Vec<usize> = vec![0, 0];
/// # let col_indices: Vec<usize> = vec![];
/// # let values: Vec<i32> = vec![];
/// # let i = 0;
/// let range = row_offsets[i] .. row_offsets[i + 1];
/// let row_i_cols = &col_indices[range.clone()];
/// let row_i_vals = &values[range];
///
/// // For each pair (j, v) in (row_i_cols, row_i_vals), we obtain a corresponding entry
/// // (i, j, v) in the matrix.
/// assert_eq!(row_i_cols.len(), row_i_vals.len());
/// ```
///
/// In the above example, for each row `i`, the column indices `row_i_cols` must appear in
/// monotonically increasing order. In other words, they must be *sorted*. This criterion is not
/// standard among all sparse matrix libraries, but we enforce this property as it is a crucial
/// assumption for both correctness and performance for many algorithms.
///
/// Note that the CSR and CSC formats are essentially identical, except that CSC stores the matrix
/// column-by-column instead of row-by-row like CSR.
///
/// [Wikipedia article]: https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsrMatrix<T> {
// Rows are major, cols are minor in the sparsity pattern
pub(crate) cs: CsMatrix<T>,
}
impl<T> CsrMatrix<T> {
/// Constructs a CSR representation of the (square) `n x n` identity matrix.
#[inline]
pub fn identity(n: usize) -> Self
where
T: Scalar + One,
{
Self {
cs: CsMatrix::identity(n),
}
}
/// Create a zero CSR matrix with no explicitly stored entries.
pub fn zeros(nrows: usize, ncols: usize) -> Self {
Self {
cs: CsMatrix::new(nrows, ncols),
}
}
/// Try to construct a CSR matrix from raw CSR data.
///
/// It is assumed that each row contains unique and sorted column indices that are in
/// bounds with respect to the number of columns in the matrix. If this is not the case,
/// an error is returned to indicate the failure.
///
/// An error is returned if the data given does not conform to the CSR storage format.
/// See the documentation for [CsrMatrix](struct.CsrMatrix.html) for more information.
pub fn try_from_csr_data(
num_rows: usize,
num_cols: usize,
row_offsets: Vec<usize>,
col_indices: Vec<usize>,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
let pattern = SparsityPattern::try_from_offsets_and_indices(
num_rows,
num_cols,
row_offsets,
col_indices,
)
.map_err(pattern_format_error_to_csr_error)?;
Self::try_from_pattern_and_values(pattern, values)
}
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
///
/// Returns an error if the number of values does not match the number of minor indices
/// in the pattern.
pub fn try_from_pattern_and_values(
pattern: SparsityPattern,
values: Vec<T>,
) -> Result<Self, SparseFormatError> {
if pattern.nnz() == values.len() {
Ok(Self {
cs: CsMatrix::from_pattern_and_values(pattern, values),
})
} else {
Err(SparseFormatError::from_kind_and_msg(
SparseFormatErrorKind::InvalidStructure,
"Number of values and column indices must be the same",
))
}
}
/// The number of rows in the matrix.
#[inline]
pub fn nrows(&self) -> usize {
self.cs.pattern().major_dim()
}
/// The number of columns in the matrix.
#[inline]
pub fn ncols(&self) -> usize {
self.cs.pattern().minor_dim()
}
/// The number of non-zeros in the matrix.
///
/// Note that this corresponds to the number of explicitly stored entries, *not* the actual
/// number of algebraically zero entries in the matrix. Explicitly stored entries can still
/// be zero. Corresponds to the number of entries in the sparsity pattern.
#[inline]
pub fn nnz(&self) -> usize {
self.cs.pattern().nnz()
}
/// The row offsets defining part of the CSR format.
#[inline]
pub fn row_offsets(&self) -> &[usize] {
let (offsets, _, _) = self.cs.cs_data();
offsets
}
/// The column indices defining part of the CSR format.
#[inline]
pub fn col_indices(&self) -> &[usize] {
let (_, indices, _) = self.cs.cs_data();
indices
}
/// The non-zero values defining part of the CSR format.
#[inline]
pub fn values(&self) -> &[T] {
self.cs.values()
}
/// Mutable access to the non-zero values.
#[inline]
pub fn values_mut(&mut self) -> &mut [T] {
self.cs.values_mut()
}
/// An iterator over non-zero triplets (i, j, v).
///
/// The iteration happens in row-major fashion, meaning that i increases monotonically,
/// and j increases monotonically within each row.
///
/// Examples
/// --------
/// ```
/// # use nalgebra_sparse::csr::CsrMatrix;
/// let row_offsets = vec![0, 2, 3, 4];
/// let col_indices = vec![0, 2, 1, 0];
/// let values = vec![1, 2, 3, 4];
/// let mut csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
/// .unwrap();
///
/// let triplets: Vec<_> = csr.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
/// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 4)]);
/// ```
pub fn triplet_iter(&self) -> CsrTripletIter<T> {
CsrTripletIter {
pattern_iter: self.pattern().entries(),
values_iter: self.values().iter(),
}
}
/// A mutable iterator over non-zero triplets (i, j, v).
///
/// Iteration happens in the same order as for [triplet_iter](#method.triplet_iter).
///
/// Examples
/// --------
/// ```
/// # use nalgebra_sparse::csr::CsrMatrix;
/// # let row_offsets = vec![0, 2, 3, 4];
/// # let col_indices = vec![0, 2, 1, 0];
/// # let values = vec![1, 2, 3, 4];
/// // Using the same data as in the `triplet_iter` example
/// let mut csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
/// .unwrap();
///
/// // Zero out lower-triangular terms
/// csr.triplet_iter_mut()
/// .filter(|(i, j, _)| j < i)
/// .for_each(|(_, _, v)| *v = 0);
///
/// let triplets: Vec<_> = csr.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
/// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 0)]);
/// ```
pub fn triplet_iter_mut(&mut self) -> CsrTripletIterMut<T> {
let (pattern, values) = self.cs.pattern_and_values_mut();
CsrTripletIterMut {
pattern_iter: pattern.entries(),
values_mut_iter: values.iter_mut(),
}
}
/// Return the row at the given row index.
///
/// Panics
/// ------
/// Panics if row index is out of bounds.
#[inline]
pub fn row(&self, index: usize) -> CsrRow<T> {
self.get_row(index).expect("Row index must be in bounds")
}
/// Mutable row access for the given row index.
///
/// Panics
/// ------
/// Panics if row index is out of bounds.
#[inline]
pub fn row_mut(&mut self, index: usize) -> CsrRowMut<T> {
self.get_row_mut(index)
.expect("Row index must be in bounds")
}
/// Return the row at the given row index, or `None` if out of bounds.
#[inline]
pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> {
self.cs.get_lane(index).map(|lane| CsrRow { lane })
}
/// Mutable row access for the given row index, or `None` if out of bounds.
#[inline]
pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> {
self.cs.get_lane_mut(index).map(|lane| CsrRowMut { lane })
}
/// An iterator over rows in the matrix.
pub fn row_iter(&self) -> CsrRowIter<T> {
CsrRowIter {
lane_iter: CsLaneIter::new(self.pattern(), self.values()),
}
}
/// A mutable iterator over rows in the matrix.
pub fn row_iter_mut(&mut self) -> CsrRowIterMut<T> {
let (pattern, values) = self.cs.pattern_and_values_mut();
CsrRowIterMut {
lane_iter: CsLaneIterMut::new(pattern, values),
}
}
/// Disassembles the CSR matrix into its underlying offset, index and value arrays.
///
/// If the matrix contains the sole reference to the sparsity pattern,
/// then the data is returned as-is. Otherwise, the sparsity pattern is cloned.
///
/// Examples
/// --------
///
/// ```
/// # use nalgebra_sparse::csr::CsrMatrix;
/// let row_offsets = vec![0, 2, 3, 4];
/// let col_indices = vec![0, 2, 1, 0];
/// let values = vec![1, 2, 3, 4];
/// let mut csr = CsrMatrix::try_from_csr_data(
/// 3,
/// 4,
/// row_offsets.clone(),
/// col_indices.clone(),
/// values.clone())
/// .unwrap();
/// let (row_offsets2, col_indices2, values2) = csr.disassemble();
/// assert_eq!(row_offsets2, row_offsets);
/// assert_eq!(col_indices2, col_indices);
/// assert_eq!(values2, values);
/// ```
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
self.cs.disassemble()
}
/// Returns the sparsity pattern and values associated with this matrix.
pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
self.cs.into_pattern_and_values()
}
/// Returns a reference to the sparsity pattern and a mutable reference to the values.
#[inline]
pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
self.cs.pattern_and_values_mut()
}
/// Returns a reference to the underlying sparsity pattern.
pub fn pattern(&self) -> &SparsityPattern {
self.cs.pattern()
}
/// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
///
/// This operation does not touch the CSR data, and is effectively a no-op.
pub fn transpose_as_csc(self) -> CscMatrix<T> {
let (pattern, values) = self.cs.take_pattern_and_values();
CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
}
/// Returns an entry for the given row/col indices, or `None` if the indices are out of bounds.
///
/// Each call to this function incurs the cost of a binary search among the explicitly
/// stored column entries for the given row.
pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> {
self.cs.get_entry(row_index, col_index)
}
/// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
/// of bounds.
///
/// Each call to this function incurs the cost of a binary search among the explicitly
/// stored column entries for the given row.
pub fn get_entry_mut(
&mut self,
row_index: usize,
col_index: usize,
) -> Option<SparseEntryMut<T>> {
self.cs.get_entry_mut(row_index, col_index)
}
/// Returns an entry for the given row/col indices.
///
/// Same as `get_entry`, except that it directly panics upon encountering row/col indices
/// out of bounds.
///
/// Panics
/// ------
/// Panics if `row_index` or `col_index` is out of bounds.
pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<T> {
self.get_entry(row_index, col_index)
.expect("Out of bounds matrix indices encountered")
}
/// Returns a mutable entry for the given row/col indices.
///
/// Same as `get_entry_mut`, except that it directly panics upon encountering row/col indices
/// out of bounds.
///
/// Panics
/// ------
/// Panics if `row_index` or `col_index` is out of bounds.
pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<T> {
self.get_entry_mut(row_index, col_index)
.expect("Out of bounds matrix indices encountered")
}
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data.
pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) {
self.cs.cs_data()
}
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data,
/// where the `values` array is mutable.
pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
self.cs.cs_data_mut()
}
/// Creates a sparse matrix that contains only the explicit entries decided by the
/// given predicate.
pub fn filter<P>(&self, predicate: P) -> Self
where
T: Clone,
P: Fn(usize, usize, &T) -> bool,
{
Self {
cs: self
.cs
.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)),
}
}
/// Returns a new matrix representing the upper triangular part of this matrix.
///
/// The result includes the diagonal of the matrix.
pub fn upper_triangle(&self) -> Self
where
T: Clone,
{
self.filter(|i, j, _| i <= j)
}
/// Returns a new matrix representing the lower triangular part of this matrix.
///
/// The result includes the diagonal of the matrix.
pub fn lower_triangle(&self) -> Self
where
T: Clone,
{
self.filter(|i, j, _| i >= j)
}
/// Returns the diagonal of the matrix as a sparse matrix.
pub fn diagonal_as_csr(&self) -> Self
where
T: Clone,
{
Self {
cs: self.cs.diagonal_as_matrix(),
}
}
/// Compute the transpose of the matrix.
pub fn transpose(&self) -> CsrMatrix<T>
where
T: Scalar,
{
CscMatrix::from(self).transpose_as_csr()
}
}
/// Convert pattern format errors into more meaningful CSR-specific errors.
///
/// This ensures that the terminology is consistent: we are talking about rows and columns,
/// not lanes, major and minor dimensions.
fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
use SparseFormatError as E;
use SparseFormatErrorKind as K;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*;
match err {
InvalidOffsetArrayLength => E::from_kind_and_msg(
K::InvalidStructure,
"Length of row offset array is not equal to nrows + 1.",
),
InvalidOffsetFirstLast => E::from_kind_and_msg(
K::InvalidStructure,
"First or last row offset is inconsistent with format specification.",
),
NonmonotonicOffsets => E::from_kind_and_msg(
K::InvalidStructure,
"Row offsets are not monotonically increasing.",
),
NonmonotonicMinorIndices => E::from_kind_and_msg(
K::InvalidStructure,
"Column indices are not monotonically increasing (sorted) within each row.",
),
MinorIndexOutOfBounds => {
E::from_kind_and_msg(K::IndexOutOfBounds, "Column indices are out of bounds.")
}
PatternDuplicateEntry => {
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
}
}
}
/// Iterator type for iterating over triplets in a CSR matrix.
#[derive(Debug)]
pub struct CsrTripletIter<'a, T> {
pattern_iter: SparsityPatternIter<'a>,
values_iter: Iter<'a, T>,
}
impl<'a, T: Clone> CsrTripletIter<'a, T> {
/// Adapts the triplet iterator to return owned values.
///
/// The triplet iterator returns references to the values. This method adapts the iterator
/// so that the values are cloned.
#[inline]
pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
self.map(|(i, j, v)| (i, j, v.clone()))
}
}
impl<'a, T> Iterator for CsrTripletIter<'a, T> {
type Item = (usize, usize, &'a T);
fn next(&mut self) -> Option<Self::Item> {
let next_entry = self.pattern_iter.next();
let next_value = self.values_iter.next();
match (next_entry, next_value) {
(Some((i, j)), Some(v)) => Some((i, j, v)),
_ => None,
}
}
}
/// Iterator type for mutably iterating over triplets in a CSR matrix.
#[derive(Debug)]
pub struct CsrTripletIterMut<'a, T> {
pattern_iter: SparsityPatternIter<'a>,
values_mut_iter: IterMut<'a, T>,
}
impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
type Item = (usize, usize, &'a mut T);
#[inline]
fn next(&mut self) -> Option<Self::Item> {
let next_entry = self.pattern_iter.next();
let next_value = self.values_mut_iter.next();
match (next_entry, next_value) {
(Some((i, j)), Some(v)) => Some((i, j, v)),
_ => None,
}
}
}
/// An immutable representation of a row in a CSR matrix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CsrRow<'a, T> {
lane: CsLane<'a, T>,
}
/// A mutable representation of a row in a CSR matrix.
///
/// Note that only explicitly stored entries can be mutated. The sparsity pattern belonging
/// to the row cannot be modified.
#[derive(Debug, PartialEq, Eq)]
pub struct CsrRowMut<'a, T> {
lane: CsLaneMut<'a, T>,
}
/// Implement the methods common to both CsrRow and CsrRowMut
macro_rules! impl_csr_row_common_methods {
($name:ty) => {
impl<'a, T> $name {
/// The number of global columns in the row.
#[inline]
pub fn ncols(&self) -> usize {
self.lane.minor_dim()
}
/// The number of non-zeros in this row.
#[inline]
pub fn nnz(&self) -> usize {
self.lane.nnz()
}
/// The column indices corresponding to explicitly stored entries in this row.
#[inline]
pub fn col_indices(&self) -> &[usize] {
self.lane.minor_indices()
}
/// The values corresponding to explicitly stored entries in this row.
#[inline]
pub fn values(&self) -> &[T] {
self.lane.values()
}
/// Returns an entry for the given global column index.
///
/// Each call to this function incurs the cost of a binary search among the explicitly
/// stored column entries.
#[inline]
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
self.lane.get_entry(global_col_index)
}
}
};
}
impl_csr_row_common_methods!(CsrRow<'a, T>);
impl_csr_row_common_methods!(CsrRowMut<'a, T>);
impl<'a, T> CsrRowMut<'a, T> {
/// Mutable access to the values corresponding to explicitly stored entries in this row.
#[inline]
pub fn values_mut(&mut self) -> &mut [T] {
self.lane.values_mut()
}
/// Provides simultaneous access to column indices and mutable values corresponding to the
/// explicitly stored entries in this row.
///
/// This method primarily facilitates low-level access for methods that process data stored
/// in CSR format directly.
#[inline]
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
self.lane.indices_and_values_mut()
}
/// Returns a mutable entry for the given global column index.
#[inline]
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> {
self.lane.get_entry_mut(global_col_index)
}
}
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
pub struct CsrRowIter<'a, T> {
lane_iter: CsLaneIter<'a, T>,
}
impl<'a, T> Iterator for CsrRowIter<'a, T> {
type Item = CsrRow<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
self.lane_iter.next().map(|lane| CsrRow { lane })
}
}
/// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html).
pub struct CsrRowIterMut<'a, T> {
lane_iter: CsLaneIterMut<'a, T>,
}
impl<'a, T> Iterator for CsrRowIterMut<'a, T>
where
T: 'a,
{
type Item = CsrRowMut<'a, T>;
fn next(&mut self) -> Option<Self::Item> {
self.lane_iter.next().map(|lane| CsrRowMut { lane })
}
}

View File

@ -0,0 +1,373 @@
use crate::csc::CscMatrix;
use crate::ops::serial::spsolve_csc_lower_triangular;
use crate::ops::Op;
use crate::pattern::SparsityPattern;
use core::{iter, mem};
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
use std::fmt::{Display, Formatter};
/// A symbolic sparse Cholesky factorization of a CSC matrix.
///
/// The symbolic factorization computes the sparsity pattern of `L`, the Cholesky factor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CscSymbolicCholesky {
// Pattern of the original matrix that was decomposed
m_pattern: SparsityPattern,
l_pattern: SparsityPattern,
// u in this context is L^T, so that M = L L^T
u_pattern: SparsityPattern,
}
impl CscSymbolicCholesky {
/// Compute the symbolic factorization for a sparsity pattern belonging to a CSC matrix.
///
/// The sparsity pattern must be symmetric. However, this is not enforced, and it is the
/// responsibility of the user to ensure that this property holds.
///
/// # Panics
///
/// Panics if the sparsity pattern is not square.
pub fn factor(pattern: SparsityPattern) -> Self {
assert_eq!(
pattern.major_dim(),
pattern.minor_dim(),
"Major and minor dimensions must be the same (square matrix)."
);
let (l_pattern, u_pattern) = nonzero_pattern(&pattern);
Self {
m_pattern: pattern,
l_pattern,
u_pattern,
}
}
/// The pattern of the Cholesky factor `L`.
pub fn l_pattern(&self) -> &SparsityPattern {
&self.l_pattern
}
}
/// A sparse Cholesky factorization `A = L L^T` of a [`CscMatrix`].
///
/// The factor `L` is a sparse, lower-triangular matrix. See the article on [Wikipedia] for
/// more information.
///
/// The implementation is a port of the `CsCholesky` implementation in `nalgebra`. It is similar
/// to Tim Davis' [`CSparse`]. The current implementation performs no fill-in reduction, and can
/// therefore be expected to produce much too dense Cholesky factors for many matrices.
/// It is therefore not currently recommended to use this implementation for serious projects.
///
/// [`CSparse`]: https://epubs.siam.org/doi/book/10.1137/1.9780898718881
/// [Wikipedia]: https://en.wikipedia.org/wiki/Cholesky_decomposition
// TODO: We should probably implement PartialEq/Eq, but in that case we'd probably need a
// custom implementation, due to the need to exclude the workspace arrays
#[derive(Debug, Clone)]
pub struct CscCholesky<T> {
// Pattern of the original matrix
m_pattern: SparsityPattern,
l_factor: CscMatrix<T>,
u_pattern: SparsityPattern,
work_x: Vec<T>,
work_c: Vec<usize>,
}
#[derive(Debug, PartialEq, Eq, Clone)]
#[non_exhaustive]
/// Possible errors produced by the Cholesky factorization.
pub enum CholeskyError {
/// The matrix is not positive definite.
NotPositiveDefinite,
}
impl Display for CholeskyError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "Matrix is not positive definite")
}
}
impl std::error::Error for CholeskyError {}
impl<T: RealField> CscCholesky<T> {
/// Computes the numerical Cholesky factorization associated with the given
/// symbolic factorization and the provided values.
///
/// The values correspond to the non-zero values of the CSC matrix for which the
/// symbolic factorization was computed.
///
/// # Errors
///
/// Returns an error if the numerical factorization fails. This can occur if the matrix is not
/// symmetric positive definite.
///
/// # Panics
///
/// Panics if the number of values differ from the number of non-zeros of the sparsity pattern
/// of the matrix that was symbolically factored.
pub fn factor_numerical(
symbolic: CscSymbolicCholesky,
values: &[T],
) -> Result<Self, CholeskyError> {
assert_eq!(
symbolic.l_pattern.nnz(),
symbolic.u_pattern.nnz(),
"u is just the transpose of l, so should have the same nnz"
);
let l_nnz = symbolic.l_pattern.nnz();
let l_values = vec![T::zero(); l_nnz];
let l_factor =
CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values).unwrap();
let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols());
let mut factorization = CscCholesky {
m_pattern: symbolic.m_pattern,
l_factor,
u_pattern: symbolic.u_pattern,
work_x: vec![T::zero(); nrows],
// Fill with MAX so that things hopefully totally fail if values are not
// overwritten. Might be easier to debug this way
work_c: vec![usize::MAX, ncols],
};
factorization.refactor(values)?;
Ok(factorization)
}
/// Computes the Cholesky factorization of the provided matrix.
///
/// The matrix must be symmetric positive definite. Symmetry is not checked, and it is up
/// to the user to enforce this property.
///
/// # Errors
///
/// Returns an error if the numerical factorization fails. This can occur if the matrix is not
/// symmetric positive definite.
///
/// # Panics
///
/// Panics if the matrix is not square.
pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> {
let symbolic = CscSymbolicCholesky::factor(matrix.pattern().clone());
Self::factor_numerical(symbolic, matrix.values())
}
/// Re-computes the factorization for a new set of non-zero values.
///
/// This is useful when the values of a matrix changes, but the sparsity pattern remains
/// constant.
///
/// # Errors
///
/// Returns an error if the numerical factorization fails. This can occur if the matrix is not
/// symmetric positive definite.
///
/// # Panics
///
/// Panics if the number of values does not match the number of non-zeros in the sparsity
/// pattern.
pub fn refactor(&mut self, values: &[T]) -> Result<(), CholeskyError> {
self.decompose_left_looking(values)
}
/// Returns a reference to the Cholesky factor `L`.
pub fn l(&self) -> &CscMatrix<T> {
&self.l_factor
}
/// Returns the Cholesky factor `L`.
pub fn take_l(self) -> CscMatrix<T> {
self.l_factor
}
/// Perform a numerical left-looking cholesky decomposition of a matrix with the same structure as the
/// one used to initialize `self`, but with different non-zero values provided by `values`.
fn decompose_left_looking(&mut self, values: &[T]) -> Result<(), CholeskyError> {
assert!(
values.len() >= self.m_pattern.nnz(),
// TODO: Improve error message
"The set of values is too small."
);
let n = self.l_factor.nrows();
// Reset `work_c` to the column pointers of `l`.
self.work_c.clear();
self.work_c.extend_from_slice(self.l_factor.col_offsets());
unsafe {
for k in 0..n {
// Scatter the k-th column of the original matrix with the values provided.
let range_begin = *self.m_pattern.major_offsets().get_unchecked(k);
let range_end = *self.m_pattern.major_offsets().get_unchecked(k + 1);
let range_k = range_begin..range_end;
*self.work_x.get_unchecked_mut(k) = T::zero();
for p in range_k.clone() {
let irow = *self.m_pattern.minor_indices().get_unchecked(p);
if irow >= k {
*self.work_x.get_unchecked_mut(irow) = *values.get_unchecked(p);
}
}
for &j in self.u_pattern.lane(k) {
let factor = -*self
.l_factor
.values()
.get_unchecked(*self.work_c.get_unchecked(j));
*self.work_c.get_unchecked_mut(j) += 1;
if j < k {
let col_j = self.l_factor.col(j);
let col_j_entries = col_j.row_indices().iter().zip(col_j.values());
for (&z, val) in col_j_entries {
if z >= k {
*self.work_x.get_unchecked_mut(z) += val.inlined_clone() * factor;
}
}
}
}
let diag = *self.work_x.get_unchecked(k);
if diag > T::zero() {
let denom = diag.sqrt();
{
let (offsets, _, values) = self.l_factor.csc_data_mut();
*values.get_unchecked_mut(*offsets.get_unchecked(k)) = denom;
}
let mut col_k = self.l_factor.col_mut(k);
let (col_k_rows, col_k_values) = col_k.rows_and_values_mut();
let col_k_entries = col_k_rows.iter().zip(col_k_values);
for (&p, val) in col_k_entries {
*val = *self.work_x.get_unchecked(p) / denom;
*self.work_x.get_unchecked_mut(p) = T::zero();
}
} else {
return Err(CholeskyError::NotPositiveDefinite);
}
}
}
Ok(())
}
/// Solves the system `A X = B`, where `X` and `B` are dense matrices.
///
/// # Panics
///
/// Panics if `B` is not square.
pub fn solve<'a>(&'a self, b: impl Into<DMatrixSlice<'a, T>>) -> DMatrix<T> {
let b = b.into();
let mut output = b.clone_owned();
self.solve_mut(&mut output);
output
}
/// Solves the system `AX = B`, where `X` and `B` are dense matrices.
///
/// The result is stored in-place in `b`.
///
/// # Panics
///
/// Panics if `b` is not square.
pub fn solve_mut<'a>(&'a self, b: impl Into<DMatrixSliceMut<'a, T>>) {
let expect_msg = "If the Cholesky factorization succeeded,\
then the triangular solve should never fail";
// Solve LY = B
let mut y = b.into();
spsolve_csc_lower_triangular(Op::NoOp(self.l()), &mut y).expect(expect_msg);
// Solve L^T X = Y
let mut x = y;
spsolve_csc_lower_triangular(Op::Transpose(self.l()), &mut x).expect(expect_msg);
}
}
fn reach(
pattern: &SparsityPattern,
j: usize,
max_j: usize,
tree: &[usize],
marks: &mut Vec<bool>,
out: &mut Vec<usize>,
) {
marks.clear();
marks.resize(tree.len(), false);
// TODO: avoid all those allocations.
let mut tmp = Vec::new();
let mut res = Vec::new();
for &irow in pattern.lane(j) {
let mut curr = irow;
while curr != usize::max_value() && curr <= max_j && !marks[curr] {
marks[curr] = true;
tmp.push(curr);
curr = tree[curr];
}
tmp.append(&mut res);
mem::swap(&mut tmp, &mut res);
}
res.sort_unstable();
out.append(&mut res);
}
fn nonzero_pattern(m: &SparsityPattern) -> (SparsityPattern, SparsityPattern) {
let etree = elimination_tree(m);
// Note: We assume CSC, therefore rows == minor and cols == major
let (nrows, ncols) = (m.minor_dim(), m.major_dim());
let mut rows = Vec::with_capacity(m.nnz());
let mut col_offsets = Vec::with_capacity(ncols + 1);
let mut marks = Vec::new();
// NOTE: the following will actually compute the non-zero pattern of
// the transpose of l.
col_offsets.push(0);
for i in 0..nrows {
reach(m, i, i, &etree, &mut marks, &mut rows);
col_offsets.push(rows.len());
}
let u_pattern =
SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows).unwrap();
// TODO: Avoid this transpose?
let l_pattern = u_pattern.transpose();
(l_pattern, u_pattern)
}
fn elimination_tree(pattern: &SparsityPattern) -> Vec<usize> {
// Note: The pattern is assumed to of a CSC matrix, so the number of rows is
// given by the minor dimension
let nrows = pattern.minor_dim();
let mut forest: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
let mut ancestor: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
for k in 0..nrows {
for &irow in pattern.lane(k) {
let mut i = irow;
while i < k {
let i_ancestor = ancestor[i];
ancestor[i] = k;
if i_ancestor == usize::max_value() {
forest[i] = k;
break;
}
i = i_ancestor;
}
}
}
forest
}

View File

@ -0,0 +1,6 @@
//! Matrix factorization for sparse matrices.
//!
//! Currently, the only factorization provided here is the [`CscCholesky`] factorization.
mod cholesky;
pub use cholesky::*;

267
nalgebra-sparse/src/lib.rs Normal file
View File

@ -0,0 +1,267 @@
//! Sparse matrices and algorithms for [nalgebra](https://www.nalgebra.org).
//!
//! This crate extends `nalgebra` with sparse matrix formats and operations on sparse matrices.
//!
//! ## Goals
//! The long-term goals for this crate are listed below.
//!
//! - Provide proven sparse matrix formats in an easy-to-use and idiomatic Rust API that
//! naturally integrates with `nalgebra`.
//! - Provide additional expert-level APIs for fine-grained control over operations.
//! - Integrate well with external sparse matrix libraries.
//! - Provide native Rust high-performance routines, including parallel matrix operations.
//!
//! ## Highlighted current features
//!
//! - [CSR](csr::CsrMatrix), [CSC](csc::CscMatrix) and [COO](coo::CooMatrix) formats, and
//! [conversions](`convert`) between them.
//! - Common arithmetic operations are implemented. See the [`ops`] module.
//! - Sparsity patterns in CSR and CSC matrices are explicitly represented by the
//! [SparsityPattern](pattern::SparsityPattern) type, which encodes the invariants of the
//! associated index data structures.
//! - [proptest strategies](`proptest`) for sparse matrices when the feature
//! `proptest-support` is enabled.
//! - [matrixcompare support](https://crates.io/crates/matrixcompare) for effortless
//! (approximate) comparison of matrices in test code (requires the `compare` feature).
//!
//! ## Current state
//!
//! The library is in an early, but usable state. The API has been designed to be extensible,
//! but breaking changes will be necessary to implement several planned features. While it is
//! backed by an extensive test suite, it has yet to be thoroughly battle-tested in real
//! applications. Moreover, the focus so far has been on correctness and API design, with little
//! focus on performance. Future improvements will include incremental performance enhancements.
//!
//! Current limitations:
//!
//! - Limited or no availability of sparse system solvers.
//! - Limited support for complex numbers. Currently only arithmetic operations that do not
//! rely on particular properties of complex numbers, such as e.g. conjugation, are
//! supported.
//! - No integration with external libraries.
//!
//! # Usage
//!
//! Add the following to your `Cargo.toml` file:
//!
//! ```toml
//! [dependencies]
//! nalgebra_sparse = "0.1"
//! ```
//!
//! # Supported matrix formats
//!
//! | Format | Notes |
//! | ------------------------|--------------------------------------------- |
//! | [COO](`coo::CooMatrix`) | Well-suited for matrix construction. <br /> Ill-suited for algebraic operations. |
//! | [CSR](`csr::CsrMatrix`) | Immutable sparsity pattern, suitable for algebraic operations. <br /> Fast row access. |
//! | [CSC](`csc::CscMatrix`) | Immutable sparsity pattern, suitable for algebraic operations. <br /> Fast column access. |
//!
//! What format is best to use depends on the application. The most common use case for sparse
//! matrices in science is the solution of sparse linear systems. Here we can differentiate between
//! two common cases:
//!
//! - Direct solvers. Typically, direct solvers take their input in CSR or CSC format.
//! - Iterative solvers. Many iterative solvers require only matrix-vector products,
//! for which the CSR or CSC formats are suitable.
//!
//! The [COO](coo::CooMatrix) format is primarily intended for matrix construction.
//! A common pattern is to use COO for construction, before converting to CSR or CSC for use
//! in a direct solver or for computing matrix-vector products in an iterative solver.
//! Some high-performance applications might also directly manipulate the CSR and/or CSC
//! formats.
//!
//! # Example: COO -> CSR -> matrix-vector product
//!
//! ```rust
//! use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
//! use nalgebra::{DMatrix, DVector};
//! use matrixcompare::assert_matrix_eq;
//!
//! // The dense representation of the matrix
//! let dense = DMatrix::from_row_slice(3, 3,
//! &[1.0, 0.0, 3.0,
//! 2.0, 0.0, 1.3,
//! 0.0, 0.0, 4.1]);
//!
//! // Build the equivalent COO representation. We only add the non-zero values
//! let mut coo = CooMatrix::new(3, 3);
//! // We can add elements in any order. For clarity, we do so in row-major order here.
//! coo.push(0, 0, 1.0);
//! coo.push(0, 2, 3.0);
//! coo.push(1, 0, 2.0);
//! coo.push(1, 2, 1.3);
//! coo.push(2, 2, 4.1);
//!
//! // The simplest way to construct a CSR matrix is to first construct a COO matrix, and
//! // then convert it to CSR. The `From` trait is implemented for conversions between different
//! // sparse matrix types.
//! // Alternatively, we can construct a matrix directly from the CSR data.
//! // See the docs for CsrMatrix for how to do that.
//! let csr = CsrMatrix::from(&coo);
//!
//! // Let's check that the CSR matrix and the dense matrix represent the same matrix.
//! // We can use macros from the `matrixcompare` crate to easily do this, despite the fact that
//! // we're comparing across two different matrix formats. Note that these macros are only really
//! // appropriate for writing tests, however.
//! assert_matrix_eq!(csr, dense);
//!
//! let x = DVector::from_column_slice(&[1.3, -4.0, 3.5]);
//!
//! // Compute the matrix-vector product y = A * x. We don't need to specify the type here,
//! // but let's just do it to make sure we get what we expect
//! let y: DVector<_> = &csr * &x;
//!
//! // Verify the result with a small element-wise absolute tolerance
//! let y_expected = DVector::from_column_slice(&[11.8, 7.15, 14.35]);
//! assert_matrix_eq!(y, y_expected, comp = abs, tol = 1e-9);
//!
//! // The above expression is simple, and gives easy to read code, but if we're doing this in a
//! // loop, we'll have to keep allocating new vectors. If we determine that this is a bottleneck,
//! // then we can resort to the lower level APIs for more control over the operations
//! {
//! use nalgebra_sparse::ops::{Op, serial::spmm_csr_dense};
//! let mut y = y;
//! // Compute y <- 0.0 * y + 1.0 * csr * dense. We store the result directly in `y`, without
//! // any intermediate allocations
//! spmm_csr_dense(0.0, &mut y, 1.0, Op::NoOp(&csr), Op::NoOp(&x));
//! assert_matrix_eq!(y, y_expected, comp = abs, tol = 1e-9);
//! }
//! ```
#![deny(non_camel_case_types)]
#![deny(unused_parens)]
#![deny(non_upper_case_globals)]
#![deny(unused_qualifications)]
#![deny(unused_results)]
#![deny(missing_docs)]
pub extern crate nalgebra as na;
pub mod convert;
pub mod coo;
pub mod csc;
pub mod csr;
pub mod factorization;
pub mod ops;
pub mod pattern;
pub(crate) mod cs;
#[cfg(feature = "proptest-support")]
pub mod proptest;
#[cfg(feature = "compare")]
mod matrixcompare;
use num_traits::Zero;
use std::error::Error;
use std::fmt;
pub use self::coo::CooMatrix;
pub use self::csc::CscMatrix;
pub use self::csr::CsrMatrix;
/// Errors produced by functions that expect well-formed sparse format data.
#[derive(Debug)]
pub struct SparseFormatError {
kind: SparseFormatErrorKind,
// Currently we only use an underlying error for generating the `Display` impl
error: Box<dyn Error>,
}
impl SparseFormatError {
/// The type of error.
pub fn kind(&self) -> &SparseFormatErrorKind {
&self.kind
}
pub(crate) fn from_kind_and_error(kind: SparseFormatErrorKind, error: Box<dyn Error>) -> Self {
Self { kind, error }
}
/// Helper functionality for more conveniently creating errors.
pub(crate) fn from_kind_and_msg(kind: SparseFormatErrorKind, msg: &'static str) -> Self {
Self::from_kind_and_error(kind, Box::<dyn Error>::from(msg))
}
}
/// The type of format error described by a [SparseFormatError](struct.SparseFormatError.html).
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SparseFormatErrorKind {
/// Indicates that the index data associated with the format contains at least one index
/// out of bounds.
IndexOutOfBounds,
/// Indicates that the provided data contains at least one duplicate entry, and the
/// current format does not support duplicate entries.
DuplicateEntry,
/// Indicates that the provided data for the format does not conform to the high-level
/// structure of the format.
///
/// For example, the arrays defining the format data might have incompatible sizes.
InvalidStructure,
}
impl fmt::Display for SparseFormatError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.error)
}
}
impl Error for SparseFormatError {}
/// An entry in a sparse matrix.
///
/// Sparse matrices do not store all their entries explicitly. Therefore, entry (i, j) in the matrix
/// can either be a reference to an explicitly stored element, or it is implicitly zero.
#[derive(Debug, PartialEq, Eq)]
pub enum SparseEntry<'a, T> {
/// The entry is a reference to an explicitly stored element.
///
/// Note that the naming here is a misnomer: The element can still be zero, even though it
/// is explicitly stored (a so-called "explicit zero").
NonZero(&'a T),
/// The entry is implicitly zero, i.e. it is not explicitly stored.
Zero,
}
impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
/// Returns the value represented by this entry.
///
/// Either clones the underlying reference or returns zero if the entry is not explicitly
/// stored.
pub fn into_value(self) -> T {
match self {
SparseEntry::NonZero(value) => value.clone(),
SparseEntry::Zero => T::zero(),
}
}
}
/// A mutable entry in a sparse matrix.
///
/// See also `SparseEntry`.
#[derive(Debug, PartialEq, Eq)]
pub enum SparseEntryMut<'a, T> {
/// The entry is a mutable reference to an explicitly stored element.
///
/// Note that the naming here is a misnomer: The element can still be zero, even though it
/// is explicitly stored (a so-called "explicit zero").
NonZero(&'a mut T),
/// The entry is implicitly zero i.e. it is not explicitly stored.
Zero,
}
impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
/// Returns the value represented by this entry.
///
/// Either clones the underlying reference or returns zero if the entry is not explicitly
/// stored.
pub fn into_value(self) -> T {
match self {
SparseEntryMut::NonZero(value) => value.clone(),
SparseEntryMut::Zero => T::zero(),
}
}
}

View File

@ -0,0 +1,65 @@
//! Implements core traits for use with `matrixcompare`.
use crate::coo::CooMatrix;
use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use matrixcompare_core;
use matrixcompare_core::{Access, SparseAccess};
macro_rules! impl_matrix_for_csr_csc {
($MatrixType:ident) => {
impl<T: Clone> SparseAccess<T> for $MatrixType<T> {
fn nnz(&self) -> usize {
$MatrixType::nnz(self)
}
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
self.triplet_iter()
.map(|(i, j, v)| (i, j, v.clone()))
.collect()
}
}
impl<T: Clone> matrixcompare_core::Matrix<T> for $MatrixType<T> {
fn rows(&self) -> usize {
self.nrows()
}
fn cols(&self) -> usize {
self.ncols()
}
fn access(&self) -> Access<T> {
Access::Sparse(self)
}
}
};
}
impl_matrix_for_csr_csc!(CsrMatrix);
impl_matrix_for_csr_csc!(CscMatrix);
impl<T: Clone> SparseAccess<T> for CooMatrix<T> {
fn nnz(&self) -> usize {
CooMatrix::nnz(self)
}
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
self.triplet_iter()
.map(|(i, j, v)| (i, j, v.clone()))
.collect()
}
}
impl<T: Clone> matrixcompare_core::Matrix<T> for CooMatrix<T> {
fn rows(&self) -> usize {
self.nrows()
}
fn cols(&self) -> usize {
self.ncols()
}
fn access(&self) -> Access<T> {
Access::Sparse(self)
}
}

View File

@ -0,0 +1,331 @@
use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use crate::ops::serial::{
spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_pattern,
spmm_csc_prealloc, spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc,
};
use crate::ops::Op;
use nalgebra::allocator::Allocator;
use nalgebra::base::storage::Storage;
use nalgebra::constraint::{DimEq, ShapeConstraint};
use nalgebra::{
ClosedAdd, ClosedDiv, ClosedMul, ClosedSub, DefaultAllocator, Dim, Dynamic, Matrix, MatrixMN,
Scalar, U1,
};
use num_traits::{One, Zero};
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Neg, Sub};
/// Helper macro for implementing binary operators for different matrix types
/// See below for usage.
macro_rules! impl_bin_op {
($trait:ident, $method:ident,
<$($life:lifetime),* $(,)? $($scalar_type:ident $(: $bounds:path)?)?>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
=>
{
impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
where
// Note: The Neg bound is currently required because we delegate e.g.
// Sub to SpAdd with negative coefficients. This is not well-defined for
// unsigned data types.
$($scalar_type: $($bounds + )? Scalar + ClosedAdd + ClosedSub + ClosedMul + Zero + One + Neg<Output=T>)?
{
type Output = $ret;
fn $method(self, $b: $b_type) -> Self::Output {
let $a = self;
$body
}
}
};
}
/// Implements a +/- b for all combinations of reference and owned matrices, for
/// CsrMatrix or CscMatrix.
macro_rules! impl_sp_plus_minus {
// We first match on some special-case syntax, and forward to the actual implementation
($matrix_type:ident, $spadd_fn:ident, +) => {
impl_sp_plus_minus!(Add, add, $matrix_type, $spadd_fn, +, T::one());
};
($matrix_type:ident, $spadd_fn:ident, -) => {
impl_sp_plus_minus!(Sub, sub, $matrix_type, $spadd_fn, -, -T::one());
};
($trait:ident, $method:ident, $matrix_type:ident, $spadd_fn:ident, $sign:tt, $factor:expr) => {
impl_bin_op!($trait, $method,
<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
// If both matrices have the same pattern, then we can immediately re-use it
let pattern = spadd_pattern(a.pattern(), b.pattern());
let values = vec![T::zero(); pattern.nnz()];
// We are giving data that is valid by definition, so it is safe to unwrap below
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
.unwrap();
$spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap();
$spadd_fn(T::one(), &mut result, $factor * T::one(), Op::NoOp(&b)).unwrap();
result
});
impl_bin_op!($trait, $method,
<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
&a $sign b
});
impl_bin_op!($trait, $method,
<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
a $sign &b
});
impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
a $sign &b
});
}
}
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, +);
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, -);
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, +);
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, -);
macro_rules! impl_mul {
($($args:tt)*) => {
impl_bin_op!(Mul, mul, $($args)*);
}
}
/// Implements a + b for all combinations of reference and owned matrices, for
/// CsrMatrix or CscMatrix.
macro_rules! impl_spmm {
($matrix_type:ident, $pattern_fn:expr, $spmm_fn:expr) => {
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
let pattern = $pattern_fn(a.pattern(), b.pattern());
let values = vec![T::zero(); pattern.nnz()];
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
.unwrap();
$spmm_fn(T::zero(),
&mut result,
T::one(),
Op::NoOp(a),
Op::NoOp(b))
.expect("Internal error: spmm failed (please debug).");
result
});
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { a * &b});
impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { &a * b});
impl_mul!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { &a * &b});
}
}
impl_spmm!(CsrMatrix, spmm_csr_pattern, spmm_csr_prealloc);
// Need to switch order of operations for CSC pattern
impl_spmm!(CscMatrix, spmm_csc_pattern, spmm_csc_prealloc);
/// Implements Scalar * Matrix operations for *concrete* scalar types. The reason this is necessary
/// is that we are not able to implement Mul<Matrix<T>> for all T generically due to orphan rules.
macro_rules! impl_concrete_scalar_matrix_mul {
($matrix_type:ident, $($scalar_type:ty),*) => {
// For each concrete scalar type, forward the implementation of scalar * matrix
// to matrix * scalar, which we have already implemented through generics
$(
impl_mul!(<>(a: $scalar_type, b: $matrix_type<$scalar_type>)
-> $matrix_type<$scalar_type> { b * a });
impl_mul!(<'a>(a: $scalar_type, b: &'a $matrix_type<$scalar_type>)
-> $matrix_type<$scalar_type> { b * a });
impl_mul!(<'a>(a: &'a $scalar_type, b: $matrix_type<$scalar_type>)
-> $matrix_type<$scalar_type> { b * (*a) });
impl_mul!(<'a>(a: &'a $scalar_type, b: &'a $matrix_type<$scalar_type>)
-> $matrix_type<$scalar_type> { b * *a });
)*
}
}
/// Implements multiplication between matrix and scalar for various matrix types
macro_rules! impl_scalar_mul {
($matrix_type: ident) => {
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
let values: Vec<_> = a.values()
.iter()
.map(|v_i| v_i.inlined_clone() * b.inlined_clone())
.collect();
$matrix_type::try_from_pattern_and_values(a.pattern().clone(), values).unwrap()
});
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> {
a * &b
});
impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
let mut a = a;
for value in a.values_mut() {
*value = b.inlined_clone() * value.inlined_clone();
}
a
});
impl_mul!(<T>(a: $matrix_type<T>, b: T) -> $matrix_type<T> {
a * &b
});
impl_concrete_scalar_matrix_mul!(
$matrix_type,
i8, i16, i32, i64, isize, f32, f64);
impl<T> MulAssign<T> for $matrix_type<T>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
fn mul_assign(&mut self, scalar: T) {
for val in self.values_mut() {
*val *= scalar.inlined_clone();
}
}
}
impl<'a, T> MulAssign<&'a T> for $matrix_type<T>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One
{
fn mul_assign(&mut self, scalar: &'a T) {
for val in self.values_mut() {
*val *= scalar.inlined_clone();
}
}
}
}
}
impl_scalar_mul!(CsrMatrix);
impl_scalar_mul!(CscMatrix);
macro_rules! impl_neg {
($matrix_type:ident) => {
impl<T> Neg for $matrix_type<T>
where
T: Scalar + Neg<Output = T>,
{
type Output = $matrix_type<T>;
fn neg(mut self) -> Self::Output {
for v_i in self.values_mut() {
*v_i = -v_i.inlined_clone();
}
self
}
}
impl<'a, T> Neg for &'a $matrix_type<T>
where
T: Scalar + Neg<Output = T>,
{
type Output = $matrix_type<T>;
fn neg(self) -> Self::Output {
// TODO: This is inefficient. Ideally we'd have a method that would let us
// obtain both the sparsity pattern and values from the matrix,
// and then modify the values before creating a new matrix from the pattern
// and negated values.
-self.clone()
}
}
};
}
impl_neg!(CsrMatrix);
impl_neg!(CscMatrix);
macro_rules! impl_div {
($matrix_type:ident) => {
impl_bin_op!(Div, div, <T: ClosedDiv>(matrix: $matrix_type<T>, scalar: T) -> $matrix_type<T> {
let mut matrix = matrix;
matrix /= scalar;
matrix
});
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: $matrix_type<T>, scalar: &T) -> $matrix_type<T> {
matrix / scalar.inlined_clone()
});
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: T) -> $matrix_type<T> {
let new_values = matrix.values()
.iter()
.map(|v_i| v_i.inlined_clone() / scalar.inlined_clone())
.collect();
$matrix_type::try_from_pattern_and_values(matrix.pattern().clone(), new_values)
.unwrap()
});
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> {
matrix / scalar.inlined_clone()
});
impl<T> DivAssign<T> for $matrix_type<T>
where T : Scalar + ClosedAdd + ClosedMul + ClosedDiv + Zero + One
{
fn div_assign(&mut self, scalar: T) {
self.values_mut().iter_mut().for_each(|v_i| *v_i /= scalar.inlined_clone());
}
}
impl<'a, T> DivAssign<&'a T> for $matrix_type<T>
where T : Scalar + ClosedAdd + ClosedMul + ClosedDiv + Zero + One
{
fn div_assign(&mut self, scalar: &'a T) {
*self /= scalar.inlined_clone();
}
}
}
}
impl_div!(CsrMatrix);
impl_div!(CscMatrix);
macro_rules! impl_spmm_cs_dense {
($matrix_type_name:ident, $spmm_fn:ident) => {
// Implement ref-ref
impl_spmm_cs_dense!(&'a $matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
let (_, ncols) = rhs.data.shape();
let nrows = Dynamic::new(lhs.nrows());
let mut result = MatrixMN::<T, Dynamic, C>::zeros_generic(nrows, ncols);
$spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(lhs), Op::NoOp(rhs));
result
});
// Implement the other combinations by deferring to ref-ref
impl_spmm_cs_dense!(&'a $matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
lhs * &rhs
});
impl_spmm_cs_dense!($matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
&lhs * rhs
});
impl_spmm_cs_dense!($matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
&lhs * &rhs
});
};
// Main body of the macro. The first pattern just forwards to this pattern but with
// different arguments
($sparse_matrix_type:ty, $dense_matrix_type:ty, $spmm_fn:ident,
|$lhs:ident, $rhs:ident| $body:tt) =>
{
impl<'a, T, R, C, S> Mul<$dense_matrix_type> for $sparse_matrix_type
where
T: Scalar + ClosedMul + ClosedAdd + ClosedSub + ClosedDiv + Neg + Zero + One,
R: Dim,
C: Dim,
S: Storage<T, R, C>,
DefaultAllocator: Allocator<T, Dynamic, C>,
// TODO: Is it possible to simplify these bounds?
ShapeConstraint:
// Bounds so that we can turn MatrixMN<T, Dynamic, C> into a DMatrixSliceMut
DimEq<U1, <<DefaultAllocator as Allocator<T, Dynamic, C>>::Buffer as Storage<T, Dynamic, C>>::RStride>
+ DimEq<C, Dynamic>
+ DimEq<Dynamic, <<DefaultAllocator as Allocator<T, Dynamic, C>>::Buffer as Storage<T, Dynamic, C>>::CStride>
// Bounds so that we can turn &Matrix<T, R, C, S> into a DMatrixSlice
+ DimEq<U1, S::RStride>
+ DimEq<R, Dynamic>
+ DimEq<Dynamic, S::CStride>
{
// We need the column dimension to be generic, so that if RHS is a vector, then
// we also get a vector (and not a matrix)
type Output = MatrixMN<T, Dynamic, C>;
fn mul(self, rhs: $dense_matrix_type) -> Self::Output {
let $lhs = self;
let $rhs = rhs;
$body
}
}
}
}
impl_spmm_cs_dense!(CsrMatrix, spmm_csr_dense);
impl_spmm_cs_dense!(CscMatrix, spmm_csc_dense);

View File

@ -0,0 +1,194 @@
//! Sparse matrix arithmetic operations.
//!
//! This module contains a number of routines for sparse matrix arithmetic. These routines are
//! primarily intended for "expert usage". Most users should prefer to use standard
//! `std::ops` operations for simple and readable code when possible. The routines provided here
//! offer more control over allocation, and allow fusing some low-level operations for higher
//! performance.
//!
//! The available operations are organized by backend. Currently, only the [`serial`] backend
//! is available. In the future, backends that expose parallel operations may become available.
//! All `std::ops` implementations will remain single-threaded and powered by the
//! `serial` backend.
//!
//! Many routines are able to implicitly transpose matrices involved in the operation.
//! For example, the routine [`spadd_csr_prealloc`](serial::spadd_csr_prealloc) performs the
//! operation `C <- beta * C + alpha * op(A)`. Here `op(A)` indicates that the matrix `A` can
//! either be used as-is or transposed. The notation `op(A)` is represented in code by the
//! [`Op`] enum.
//!
//! # Available `std::ops` implementations
//!
//! ## Binary operators
//!
//! The below table summarizes the currently supported binary operators between matrices.
//! In general, binary operators between sparse matrices are only supported if both matrices
//! are stored in the same format. All supported binary operators are implemented for
//! all four combinations of values and references.
//!
//! <table>
//! <tr>
//! <th>LHS (down) \ RHS (right)</th>
//! <th>COO</th>
//! <th>CSR</th>
//! <th>CSC</th>
//! <th>Dense</th>
//! </tr>
//! <tr>
//! <th>COO</th>
//! <td></td>
//! <td></td>
//! <td></td>
//! <td></td>
//! </tr>
//! <tr>
//! <th>CSR</th>
//! <td></td>
//! <td>+ - *</td>
//! <td></td>
//! <td>*</td>
//! </tr>
//! <tr>
//! <th>CSC</th>
//! <td></td>
//! <td></td>
//! <td>+ - *</td>
//! <td>*</td>
//! </tr>
//! <tr>
//! <th>Dense</th>
//! <td></td>
//! <td></td>
//! <td></td>
//! <td>+ - *</td>
//! </tr>
//! </table>
//!
//! As can be seen from the table, only `CSR * Dense` and `CSC * Dense` are supported.
//! The other way around, i.e. `Dense * CSR` and `Dense * CSC` are not implemented.
//!
//! Additionally, [CsrMatrix](`crate::csr::CsrMatrix`) and [CooMatrix](`crate::coo::CooMatrix`)
//! support multiplication with scalars, in addition to division by a scalar.
//! Note that only `Matrix * Scalar` works in a generic context, although `Scalar * Matrix`
//! has been implemented for many of the built-in arithmetic types. This is due to a fundamental
//! restriction of the Rust type system. Therefore, in generic code you will need to always place
//! the matrix on the left-hand side of the multiplication.
//!
//! ## Unary operators
//!
//! The following table lists currently supported unary operators.
//!
//! | Format | AddAssign\<Matrix\> | MulAssign\<Matrix\> | MulAssign\<Scalar\> | Neg |
//! | -------- | ----------------- | ----------------- | ------------------- | ------ |
//! | COO | | | | |
//! | CSR | | | x | x |
//! | CSC | | | x | x |
//! |
//! # Example usage
//!
//! For example, consider the case where you want to compute the expression
//! `C <- 3.0 * C + 2.0 * A^T * B`, where `A`, `B`, `C` are matrices and `A^T` is the transpose
//! of `A`. The simplest way to write this is:
//!
//! ```rust
//! # use nalgebra_sparse::csr::CsrMatrix;
//! # let a = CsrMatrix::identity(10); let b = CsrMatrix::identity(10);
//! # let mut c = CsrMatrix::identity(10);
//! c = 3.0 * c + 2.0 * a.transpose() * b;
//! ```
//! This is simple and straightforward to read, and therefore the recommended way to implement
//! it. However, if you have determined that this is a performance bottleneck of your application,
//! it may be possible to speed things up. First, let's see what's going on here. The `std`
//! operations are evaluated eagerly. This means that the following steps take place:
//!
//! 1. Evaluate `let c_temp = 3.0 * c`. This requires scaling all values of the matrix.
//! 2. Evaluate `let a_t = a.transpose()` into a new temporary matrix.
//! 3. Evaluate `let a_t_b = a_t * b` into a new temporary matrix.
//! 4. Evaluate `let a_t_b_scaled = 2.0 * a_t_b`. This requires scaling all values of the matrix.
//! 5. Evaluate `c = c_temp + a_t_b_scaled`.
//!
//! An alternative way to implement this expression (here using CSR matrices) is:
//!
//! ```rust
//! # use nalgebra_sparse::csr::CsrMatrix;
//! # let a = CsrMatrix::identity(10); let b = CsrMatrix::identity(10);
//! # let mut c = CsrMatrix::identity(10);
//! use nalgebra_sparse::ops::{Op, serial::spmm_csr_prealloc};
//!
//! // Evaluate the expression `c <- 3.0 * c + 2.0 * a^T * b
//! spmm_csr_prealloc(3.0, &mut c, 2.0, Op::Transpose(&a), Op::NoOp(&b))
//! .expect("We assume that the pattern of C is able to accommodate the result.");
//! ```
//! Compared to the simpler example, this snippet is harder to read, but it calls a single
//! computational kernel that avoids many of the intermediate steps listed out before. Therefore
//! directly calling kernels may sometimes lead to better performance. However, this should
//! always be verified by performance profiling!
mod impl_std_ops;
pub mod serial;
/// Determines whether a matrix should be transposed in a given operation.
///
/// See the [module-level documentation](crate::ops) for the purpose of this enum.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Op<T> {
/// Indicates that the matrix should be used as-is.
NoOp(T),
/// Indicates that the matrix should be transposed.
Transpose(T),
}
impl<T> Op<T> {
/// Returns a reference to the inner value that the operation applies to.
pub fn inner_ref(&self) -> &T {
self.as_ref().into_inner()
}
/// Returns an `Op` applied to a reference of the inner value.
pub fn as_ref(&self) -> Op<&T> {
match self {
Op::NoOp(obj) => Op::NoOp(&obj),
Op::Transpose(obj) => Op::Transpose(&obj),
}
}
/// Converts the underlying data type.
pub fn convert<U>(self) -> Op<U>
where
T: Into<U>,
{
self.map_same_op(T::into)
}
/// Transforms the inner value with the provided function, but preserves the operation.
pub fn map_same_op<U, F: FnOnce(T) -> U>(self, f: F) -> Op<U> {
match self {
Op::NoOp(obj) => Op::NoOp(f(obj)),
Op::Transpose(obj) => Op::Transpose(f(obj)),
}
}
/// Consumes the `Op` and returns the inner value.
pub fn into_inner(self) -> T {
match self {
Op::NoOp(obj) | Op::Transpose(obj) => obj,
}
}
/// Applies the transpose operation.
///
/// This operation follows the usual semantics of transposition. In particular, double
/// transposition is equivalent to no transposition.
pub fn transposed(self) -> Self {
match self {
Op::NoOp(obj) => Op::Transpose(obj),
Op::Transpose(obj) => Op::NoOp(obj),
}
}
}
impl<T> From<T> for Op<T> {
fn from(obj: T) -> Self {
Self::NoOp(obj)
}
}

View File

@ -0,0 +1,186 @@
use crate::cs::CsMatrix;
use crate::ops::serial::{OperationError, OperationErrorKind};
use crate::ops::Op;
use crate::SparseEntryMut;
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
use num_traits::{One, Zero};
fn spmm_cs_unexpected_entry() -> OperationError {
OperationError::from_kind_and_message(
OperationErrorKind::InvalidPattern,
String::from("Found unexpected entry that is not present in `c`."),
)
}
/// Helper functionality for implementing CSR/CSC SPMM.
///
/// Since CSR/CSC matrices are basically transpositions of each other, which lets us use the same
/// algorithm for the SPMM implementation. The implementation here is written in a CSR-centric
/// manner. This means that when using it for CSC, the order of the matrices needs to be
/// reversed (since transpose(AB) = transpose(B) * transpose(A) and CSC(A) = transpose(CSR(A)).
///
/// We assume here that the matrices have already been verified to be dimensionally compatible.
pub fn spmm_cs_prealloc<T>(
beta: T,
c: &mut CsMatrix<T>,
alpha: T,
a: &CsMatrix<T>,
b: &CsMatrix<T>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
for i in 0..c.pattern().major_dim() {
let a_lane_i = a.get_lane(i).unwrap();
let mut c_lane_i = c.get_lane_mut(i).unwrap();
for c_ij in c_lane_i.values_mut() {
*c_ij = beta.inlined_clone() * c_ij.inlined_clone();
}
for (&k, a_ik) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
let b_lane_k = b.get_lane(k).unwrap();
let (mut c_lane_i_cols, mut c_lane_i_values) = c_lane_i.indices_and_values_mut();
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
// Determine the location in C to append the value
let (c_local_idx, _) = c_lane_i_cols
.iter()
.enumerate()
.find(|(_, c_col)| *c_col == j)
.ok_or_else(spmm_cs_unexpected_entry)?;
c_lane_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
c_lane_i_cols = &c_lane_i_cols[c_local_idx..];
c_lane_i_values = &mut c_lane_i_values[c_local_idx..];
}
}
}
Ok(())
}
fn spadd_cs_unexpected_entry() -> OperationError {
OperationError::from_kind_and_message(
OperationErrorKind::InvalidPattern,
String::from("Found entry in `op(a)` that is not present in `c`."),
)
}
/// Helper functionality for implementing CSR/CSC SPADD.
pub fn spadd_cs_prealloc<T>(
beta: T,
c: &mut CsMatrix<T>,
alpha: T,
a: Op<&CsMatrix<T>>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
match a {
Op::NoOp(a) => {
for (mut c_lane_i, a_lane_i) in c.lane_iter_mut().zip(a.lane_iter()) {
if beta != T::one() {
for c_ij in c_lane_i.values_mut() {
*c_ij *= beta.inlined_clone();
}
}
let (mut c_minors, mut c_vals) = c_lane_i.indices_and_values_mut();
let (a_minors, a_vals) = (a_lane_i.minor_indices(), a_lane_i.values());
for (a_col, a_val) in a_minors.iter().zip(a_vals) {
// TODO: Use exponential search instead of linear search.
// If C has substantially more entries in the row than A, then a line search
// will needlessly visit many entries in C.
let (c_idx, _) = c_minors
.iter()
.enumerate()
.find(|(_, c_col)| *c_col == a_col)
.ok_or_else(spadd_cs_unexpected_entry)?;
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
c_minors = &c_minors[c_idx..];
c_vals = &mut c_vals[c_idx..];
}
}
}
Op::Transpose(a) => {
if beta != T::one() {
for c_ij in c.values_mut() {
*c_ij *= beta.inlined_clone();
}
}
for (i, a_lane_i) in a.lane_iter().enumerate() {
for (&j, a_val) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
let a_val = a_val.inlined_clone();
let alpha = alpha.inlined_clone();
match c.get_entry_mut(j, i).unwrap() {
SparseEntryMut::NonZero(c_ji) => *c_ji += alpha * a_val,
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
}
}
}
}
}
Ok(())
}
/// Helper functionality for implementing CSR/CSC SPMM.
///
/// The implementation essentially assumes that `a` is a CSR matrix. To use it with CSC matrices,
/// the transposed operation must be specified for the CSC matrix.
pub fn spmm_cs_dense<T>(
beta: T,
mut c: DMatrixSliceMut<T>,
alpha: T,
a: Op<&CsMatrix<T>>,
b: Op<DMatrixSlice<T>>,
) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
match a {
Op::NoOp(a) => {
for j in 0..c.ncols() {
let mut c_col_j = c.column_mut(j);
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.lane_iter()) {
let mut dot_ij = T::zero();
for (&k, a_ik) in a_row_i.minor_indices().iter().zip(a_row_i.values()) {
let b_contrib = match b {
Op::NoOp(ref b) => b.index((k, j)),
Op::Transpose(ref b) => b.index((j, k)),
};
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
}
*c_ij = beta.inlined_clone() * c_ij.inlined_clone()
+ alpha.inlined_clone() * dot_ij;
}
}
}
Op::Transpose(a) => {
// In this case, we have to pre-multiply C by beta
c *= beta;
for k in 0..a.pattern().major_dim() {
let a_row_k = a.get_lane(k).unwrap();
for (&i, a_ki) in a_row_k.minor_indices().iter().zip(a_row_k.values()) {
let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone();
let mut c_row_i = c.row_mut(i);
match b {
Op::NoOp(ref b) => {
let b_row_k = b.row(k);
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
}
}
Op::Transpose(ref b) => {
let b_col_k = b.column(k);
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
}
}
}
}
}
}
}
}

View File

@ -0,0 +1,255 @@
use crate::csc::CscMatrix;
use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
use crate::ops::serial::{OperationError, OperationErrorKind};
use crate::ops::Op;
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
use num_traits::{One, Zero};
use std::borrow::Cow;
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
///
/// # Panics
///
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spmm_csc_dense<'a, T>(
beta: T,
c: impl Into<DMatrixSliceMut<'a, T>>,
alpha: T,
a: Op<&CscMatrix<T>>,
b: Op<impl Into<DMatrixSlice<'a, T>>>,
) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
let b = b.convert();
spmm_csc_dense_(beta, c.into(), alpha, a, b)
}
fn spmm_csc_dense_<T>(
beta: T,
c: DMatrixSliceMut<T>,
alpha: T,
a: Op<&CscMatrix<T>>,
b: Op<DMatrixSlice<T>>,
) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
assert_compatible_spmm_dims!(c, a, b);
// Need to interpret matrix as transposed since the spmm_cs_dense function assumes CSR layout
let a = a.transposed().map_same_op(|a| &a.cs);
spmm_cs_dense(beta, c, alpha, a, b)
}
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
///
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
/// returned.
///
/// # Panics
///
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spadd_csc_prealloc<T>(
beta: T,
c: &mut CscMatrix<T>,
alpha: T,
a: Op<&CscMatrix<T>>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
assert_compatible_spadd_dims!(c, a);
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
}
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
///
/// # Errors
///
/// If the sparsity pattern of `C` is not able to store the result of the operation,
/// an error is returned.
///
/// # Panics
///
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spmm_csc_prealloc<T>(
beta: T,
c: &mut CscMatrix<T>,
alpha: T,
a: Op<&CscMatrix<T>>,
b: Op<&CscMatrix<T>>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
assert_compatible_spmm_dims!(c, a, b);
use Op::{NoOp, Transpose};
match (&a, &b) {
(NoOp(ref a), NoOp(ref b)) => {
// Note: We have to reverse the order for CSC matrices
spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs)
}
_ => {
// Currently we handle transposition by explicitly precomputing transposed matrices
// and calling the operation again without transposition
let a_ref: &CscMatrix<T> = a.inner_ref();
let b_ref: &CscMatrix<T> = b.inner_ref();
let (a, b) = {
use Cow::*;
match (&a, &b) {
(NoOp(_), NoOp(_)) => unreachable!(),
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
(Transpose(ref a), Transpose(ref b)) => {
(Owned(a.transpose()), Owned(b.transpose()))
}
}
};
spmm_csc_prealloc(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
}
}
}
/// Solve the lower triangular system `op(L) X = B`.
///
/// Only the lower triangular part of L is read, and the result is stored in B.
///
/// # Errors
///
/// An error is returned if the system can not be solved due to the matrix being singular.
///
/// # Panics
///
/// Panics if `L` is not square, or if `L` and `B` are not dimensionally compatible.
pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
l: Op<&CscMatrix<T>>,
b: impl Into<DMatrixSliceMut<'a, T>>,
) -> Result<(), OperationError> {
let b = b.into();
let l_matrix = l.into_inner();
assert_eq!(
l_matrix.nrows(),
l_matrix.ncols(),
"Matrix must be square for triangular solve."
);
assert_eq!(
l_matrix.nrows(),
b.nrows(),
"Dimension mismatch in sparse lower triangular solver."
);
match l {
Op::NoOp(a) => spsolve_csc_lower_triangular_no_transpose(a, b),
Op::Transpose(a) => spsolve_csc_lower_triangular_transpose(a, b),
}
}
fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
l: &CscMatrix<T>,
b: DMatrixSliceMut<T>,
) -> Result<(), OperationError> {
let mut x = b;
// Solve column-by-column
for j in 0..x.ncols() {
let mut x_col_j = x.column_mut(j);
for k in 0..l.ncols() {
let l_col_k = l.col(k);
// Skip entries above the diagonal
// TODO: Can use exponential search here to quickly skip entries
// (we'd like to avoid using binary search as it's very cache unfriendly
// and the matrix might actually *be* lower triangular, which would induce
// a severe penalty)
let diag_csc_index = l_col_k.row_indices().iter().position(|&i| i == k);
if let Some(diag_csc_index) = diag_csc_index {
let l_kk = l_col_k.values()[diag_csc_index];
if l_kk != T::zero() {
// Update entry associated with diagonal
x_col_j[k] /= l_kk;
// Copy value after updating (so we don't run into the borrow checker)
let x_kj = x_col_j[k];
let row_indices = &l_col_k.row_indices()[(diag_csc_index + 1)..];
let l_values = &l_col_k.values()[(diag_csc_index + 1)..];
// Note: The remaining entries are below the diagonal
for (&i, l_ik) in row_indices.iter().zip(l_values) {
let x_ij = &mut x_col_j[i];
*x_ij -= l_ik.inlined_clone() * x_kj;
}
x_col_j[k] = x_kj;
} else {
return spsolve_encountered_zero_diagonal();
}
} else {
return spsolve_encountered_zero_diagonal();
}
}
}
Ok(())
}
fn spsolve_encountered_zero_diagonal() -> Result<(), OperationError> {
let message = "Matrix contains at least one diagonal entry that is zero.";
Err(OperationError::from_kind_and_message(
OperationErrorKind::Singular,
String::from(message),
))
}
fn spsolve_csc_lower_triangular_transpose<T: RealField>(
l: &CscMatrix<T>,
b: DMatrixSliceMut<T>,
) -> Result<(), OperationError> {
let mut x = b;
// Solve column-by-column
for j in 0..x.ncols() {
let mut x_col_j = x.column_mut(j);
// Due to the transposition, we're essentially solving an upper triangular system,
// and the columns in our matrix become rows
for i in (0..l.ncols()).rev() {
let l_col_i = l.col(i);
// Skip entries above the diagonal
// TODO: Can use exponential search here to quickly skip entries
let diag_csc_index = l_col_i.row_indices().iter().position(|&k| i == k);
if let Some(diag_csc_index) = diag_csc_index {
let l_ii = l_col_i.values()[diag_csc_index];
if l_ii != T::zero() {
// // Update entry associated with diagonal
// x_col_j[k] /= a_kk;
// Copy value after updating (so we don't run into the borrow checker)
let mut x_ii = x_col_j[i];
let row_indices = &l_col_i.row_indices()[(diag_csc_index + 1)..];
let a_values = &l_col_i.values()[(diag_csc_index + 1)..];
// Note: The remaining entries are below the diagonal
for (&k, &l_ki) in row_indices.iter().zip(a_values) {
let x_kj = x_col_j[k];
x_ii -= l_ki * x_kj;
}
x_col_j[i] = x_ii / l_ii;
} else {
return spsolve_encountered_zero_diagonal();
}
} else {
return spsolve_encountered_zero_diagonal();
}
}
}
Ok(())
}

View File

@ -0,0 +1,106 @@
use crate::csr::CsrMatrix;
use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
use crate::ops::serial::OperationError;
use crate::ops::Op;
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
use num_traits::{One, Zero};
use std::borrow::Cow;
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
pub fn spmm_csr_dense<'a, T>(
beta: T,
c: impl Into<DMatrixSliceMut<'a, T>>,
alpha: T,
a: Op<&CsrMatrix<T>>,
b: Op<impl Into<DMatrixSlice<'a, T>>>,
) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
let b = b.convert();
spmm_csr_dense_(beta, c.into(), alpha, a, b)
}
fn spmm_csr_dense_<T>(
beta: T,
c: DMatrixSliceMut<T>,
alpha: T,
a: Op<&CsrMatrix<T>>,
b: Op<DMatrixSlice<T>>,
) where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
assert_compatible_spmm_dims!(c, a, b);
spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b)
}
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
///
/// # Errors
///
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
/// returned.
///
/// # Panics
///
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spadd_csr_prealloc<T>(
beta: T,
c: &mut CsrMatrix<T>,
alpha: T,
a: Op<&CsrMatrix<T>>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
assert_compatible_spadd_dims!(c, a);
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
}
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
///
/// # Errors
///
/// If the pattern of `C` is not able to hold the result of the operation, an error is returned.
///
/// # Panics
///
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
pub fn spmm_csr_prealloc<T>(
beta: T,
c: &mut CsrMatrix<T>,
alpha: T,
a: Op<&CsrMatrix<T>>,
b: Op<&CsrMatrix<T>>,
) -> Result<(), OperationError>
where
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
{
assert_compatible_spmm_dims!(c, a, b);
use Op::{NoOp, Transpose};
match (&a, &b) {
(NoOp(ref a), NoOp(ref b)) => spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs),
_ => {
// Currently we handle transposition by explicitly precomputing transposed matrices
// and calling the operation again without transposition
// TODO: At least use workspaces to allow control of allocations. Maybe
// consider implementing certain patterns (like A^T * B) explicitly
let a_ref: &CsrMatrix<T> = a.inner_ref();
let b_ref: &CsrMatrix<T> = b.inner_ref();
let (a, b) = {
use Cow::*;
match (&a, &b) {
(NoOp(_), NoOp(_)) => unreachable!(),
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
(Transpose(ref a), Transpose(ref b)) => {
(Owned(a.transpose()), Owned(b.transpose()))
}
}
};
spmm_csr_prealloc(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
}
}
}

View File

@ -0,0 +1,124 @@
//! Serial sparse matrix arithmetic routines.
//!
//! All routines are single-threaded.
//!
//! Some operations have the `prealloc` suffix. This means that they expect that the sparsity
//! pattern of the output matrix has already been pre-allocated: that is, the pattern of the result
//! of the operation fits entirely in the output pattern. In the future, there will also be
//! some operations which will be able to dynamically adapt the output pattern to fit the
//! result, but these have yet to be implemented.
#[macro_use]
macro_rules! assert_compatible_spmm_dims {
($c:expr, $a:expr, $b:expr) => {{
use crate::ops::Op::{NoOp, Transpose};
match (&$a, &$b) {
(NoOp(ref a), NoOp(ref b)) => {
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
assert_eq!(a.ncols(), b.nrows(), "A.ncols() != B.nrows()");
}
(Transpose(ref a), NoOp(ref b)) => {
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
assert_eq!(a.nrows(), b.nrows(), "A.nrows() != B.nrows()");
}
(NoOp(ref a), Transpose(ref b)) => {
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
assert_eq!(a.ncols(), b.ncols(), "A.ncols() != B.ncols()");
}
(Transpose(ref a), Transpose(ref b)) => {
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
assert_eq!(a.nrows(), b.ncols(), "A.nrows() != B.ncols()");
}
}
}};
}
#[macro_use]
macro_rules! assert_compatible_spadd_dims {
($c:expr, $a:expr) => {
use crate::ops::Op;
match $a {
Op::NoOp(a) => {
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
assert_eq!($c.ncols(), a.ncols(), "C.ncols() != A.ncols()");
}
Op::Transpose(a) => {
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
assert_eq!($c.ncols(), a.nrows(), "C.ncols() != A.nrows()");
}
}
};
}
mod cs;
mod csc;
mod csr;
mod pattern;
pub use csc::*;
pub use csr::*;
pub use pattern::*;
use std::fmt;
use std::fmt::Formatter;
/// A description of the error that occurred during an arithmetic operation.
#[derive(Clone, Debug)]
pub struct OperationError {
error_kind: OperationErrorKind,
message: String,
}
/// The different kinds of operation errors that may occur.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum OperationErrorKind {
/// Indicates that one or more sparsity patterns involved in the operation violate the
/// expectations of the routine.
///
/// For example, this could indicate that the sparsity pattern of the output is not able to
/// contain the result of the operation.
InvalidPattern,
/// Indicates that a matrix is singular when it is expected to be invertible.
Singular,
}
impl OperationError {
fn from_kind_and_message(error_type: OperationErrorKind, message: String) -> Self {
Self {
error_kind: error_type,
message,
}
}
/// The operation error kind.
pub fn kind(&self) -> &OperationErrorKind {
&self.error_kind
}
/// The underlying error message.
pub fn message(&self) -> &str {
self.message.as_str()
}
}
impl fmt::Display for OperationError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "Sparse matrix operation error: ")?;
match self.kind() {
OperationErrorKind::InvalidPattern => {
write!(f, "InvalidPattern")?;
}
OperationErrorKind::Singular => {
write!(f, "Singular")?;
}
}
write!(f, " Message: {}", self.message)
}
}
impl std::error::Error for OperationError {}

View File

@ -0,0 +1,152 @@
use crate::pattern::SparsityPattern;
use std::iter;
/// Sparse matrix addition pattern construction, `C <- A + B`.
///
/// Builds the pattern for `C`, which is able to hold the result of the sum `A + B`.
/// The patterns are assumed to have the same major and minor dimensions. In other words,
/// both patterns `A` and `B` must both stem from the same kind of compressed matrix:
/// CSR or CSC.
///
/// # Panics
///
/// Panics if the patterns do not have the same major and minor dimensions.
pub fn spadd_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
assert_eq!(
a.major_dim(),
b.major_dim(),
"Patterns must have identical major dimensions."
);
assert_eq!(
a.minor_dim(),
b.minor_dim(),
"Patterns must have identical minor dimensions."
);
let mut offsets = Vec::new();
let mut indices = Vec::new();
offsets.reserve(a.major_dim() + 1);
indices.clear();
offsets.push(0);
for lane_idx in 0..a.major_dim() {
let lane_a = a.lane(lane_idx);
let lane_b = b.lane(lane_idx);
indices.extend(iterate_union(lane_a, lane_b));
offsets.push(indices.len());
}
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
SparsityPattern::try_from_offsets_and_indices(a.major_dim(), a.minor_dim(), offsets, indices)
.expect("Internal error: Pattern must be valid by definition")
}
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
///
/// Assumes that the sparsity patterns both represent CSC matrices, and the result is also
/// represented as the sparsity pattern of a CSC matrix.
///
/// # Panics
///
/// Panics if the patterns, when interpreted as CSC patterns, are not compatible for
/// matrix multiplication.
pub fn spmm_csc_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
// Let C = A * B in CSC format. We note that
// C^T = B^T * A^T.
// Since the interpretation of a CSC matrix in CSR format represents the transpose of the
// matrix in CSR, we can compute C^T in *CSR format* by switching the order of a and b,
// which lets us obtain C^T in CSR format. Re-interpreting this as CSC gives us C in CSC format
spmm_csr_pattern(b, a)
}
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
///
/// Assumes that the sparsity patterns both represent CSR matrices, and the result is also
/// represented as the sparsity pattern of a CSR matrix.
///
/// # Panics
///
/// Panics if the patterns, when interpreted as CSR patterns, are not compatible for
/// matrix multiplication.
pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
assert_eq!(
a.minor_dim(),
b.major_dim(),
"a and b must have compatible dimensions"
);
let mut offsets = Vec::new();
let mut indices = Vec::new();
offsets.push(0);
// Keep a vector of whether we have visited a particular minor index when working
// on a major lane
// TODO: Consider using a bitvec or similar here to reduce pressure on memory
// (would cut memory use to 1/8, which might help reduce cache misses)
let mut visited = vec![false; b.minor_dim()];
for i in 0..a.major_dim() {
let a_lane_i = a.lane(i);
let c_lane_i_offset = *offsets.last().unwrap();
for &k in a_lane_i {
let b_lane_k = b.lane(k);
for &j in b_lane_k {
let have_visited_j = &mut visited[j];
if !*have_visited_j {
indices.push(j);
*have_visited_j = true;
}
}
}
let c_lane_i = &mut indices[c_lane_i_offset..];
c_lane_i.sort_unstable();
// Reset visits so that visited[j] == false for all j for the next major lane
for j in c_lane_i {
visited[*j] = false;
}
offsets.push(indices.len());
}
SparsityPattern::try_from_offsets_and_indices(a.major_dim(), b.minor_dim(), offsets, indices)
.expect("Internal error: Invalid pattern during matrix multiplication pattern construction")
}
/// Iterate over the union of the two sets represented by sorted slices
/// (with unique elements)
fn iterate_union<'a>(
mut sorted_a: &'a [usize],
mut sorted_b: &'a [usize],
) -> impl Iterator<Item = usize> + 'a {
iter::from_fn(move || {
if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) {
let item = if a_item < b_item {
sorted_a = &sorted_a[1..];
a_item
} else if b_item < a_item {
sorted_b = &sorted_b[1..];
b_item
} else {
// Both lists contain the same element, advance both slices to avoid
// duplicate entries in the result
sorted_a = &sorted_a[1..];
sorted_b = &sorted_b[1..];
a_item
};
Some(*item)
} else if let Some(a_item) = sorted_a.first() {
sorted_a = &sorted_a[1..];
Some(*a_item)
} else if let Some(b_item) = sorted_b.first() {
sorted_b = &sorted_b[1..];
Some(*b_item)
} else {
None
}
})
}

View File

@ -0,0 +1,393 @@
//! Sparsity patterns for CSR and CSC matrices.
use crate::cs::transpose_cs;
use crate::SparseFormatError;
use std::error::Error;
use std::fmt;
/// A representation of the sparsity pattern of a CSR or CSC matrix.
///
/// CSR and CSC matrices store matrices in a very similar fashion. In fact, in a certain sense,
/// they are transposed. More precisely, when reinterpreting the three data arrays of a CSR
/// matrix as a CSC matrix, we obtain the CSC representation of its transpose.
///
/// [`SparsityPattern`] is an abstraction built on this observation. Whereas CSR matrices
/// store a matrix row-by-row, and a CSC matrix stores a matrix column-by-column, a
/// `SparsityPattern` represents only the index data structure of a matrix *lane-by-lane*.
/// Here, a *lane* is a generalization of rows and columns. We further define *major lanes*
/// and *minor lanes*. The sparsity pattern of a CSR matrix is then obtained by interpreting
/// major/minor as row/column. Conversely, we obtain the sparsity pattern of a CSC matrix by
/// interpreting major/minor as column/row.
///
/// This allows us to use a common abstraction to talk about sparsity patterns of CSR and CSC
/// matrices. This is convenient, because at the abstract level, the invariants of the formats
/// are the same. Hence we may encode the invariants of the index data structure separately from
/// the scalar values of the matrix. This is especially useful in applications where the
/// sparsity pattern is built ahead of the matrix values, or the same sparsity pattern is re-used
/// between different matrices. Finally, we can use `SparsityPattern` to encode adjacency
/// information in graphs.
///
/// # Format
///
/// The format is exactly the same as for the index data structures of CSR and CSC matrices.
/// This means that the sparsity pattern of an `m x n` sparse matrix with `nnz` non-zeros,
/// where in this case `m x n` does *not* mean `rows x columns`, but rather `majors x minors`,
/// is represented by the following two arrays:
///
/// - `major_offsets`, an array of integers with length `m + 1`.
/// - `minor_indices`, an array of integers with length `nnz`.
///
/// The invariants and relationship between `major_offsets` and `minor_indices` remain the same
/// as for `row_offsets` and `col_indices` in the [CSR](`crate::csr::CsrMatrix`) format
/// specification.
#[derive(Debug, Clone, PartialEq, Eq)]
// TODO: Make SparsityPattern parametrized by index type
// (need a solid abstraction for index types though)
pub struct SparsityPattern {
major_offsets: Vec<usize>,
minor_indices: Vec<usize>,
minor_dim: usize,
}
impl SparsityPattern {
/// Create a sparsity pattern of the given dimensions without explicitly stored entries.
pub fn zeros(major_dim: usize, minor_dim: usize) -> Self {
Self {
major_offsets: vec![0; major_dim + 1],
minor_indices: vec![],
minor_dim,
}
}
/// The offsets for the major dimension.
#[inline]
pub fn major_offsets(&self) -> &[usize] {
&self.major_offsets
}
/// The indices for the minor dimension.
#[inline]
pub fn minor_indices(&self) -> &[usize] {
&self.minor_indices
}
/// The number of major lanes in the pattern.
#[inline]
pub fn major_dim(&self) -> usize {
assert!(self.major_offsets.len() > 0);
self.major_offsets.len() - 1
}
/// The number of minor lanes in the pattern.
#[inline]
pub fn minor_dim(&self) -> usize {
self.minor_dim
}
/// The number of "non-zeros", i.e. explicitly stored entries in the pattern.
#[inline]
pub fn nnz(&self) -> usize {
self.minor_indices.len()
}
/// Get the lane at the given index.
///
/// Panics
/// ------
///
/// Panics if `major_index` is out of bounds.
#[inline]
pub fn lane(&self, major_index: usize) -> &[usize] {
self.get_lane(major_index).unwrap()
}
/// Get the lane at the given index, or `None` if out of bounds.
#[inline]
pub fn get_lane(&self, major_index: usize) -> Option<&[usize]> {
let offset_begin = *self.major_offsets().get(major_index)?;
let offset_end = *self.major_offsets().get(major_index + 1)?;
Some(&self.minor_indices()[offset_begin..offset_end])
}
/// Try to construct a sparsity pattern from the given dimensions, major offsets
/// and minor indices.
///
/// Returns an error if the data does not conform to the requirements.
pub fn try_from_offsets_and_indices(
major_dim: usize,
minor_dim: usize,
major_offsets: Vec<usize>,
minor_indices: Vec<usize>,
) -> Result<Self, SparsityPatternFormatError> {
use SparsityPatternFormatError::*;
if major_offsets.len() != major_dim + 1 {
return Err(InvalidOffsetArrayLength);
}
// Check that the first and last offsets conform to the specification
{
let first_offset_ok = *major_offsets.first().unwrap() == 0;
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
if !first_offset_ok || !last_offset_ok {
return Err(InvalidOffsetFirstLast);
}
}
// Test that each lane has strictly monotonically increasing minor indices, i.e.
// minor indices within a lane are sorted, unique. In addition, each minor index
// must be in bounds with respect to the minor dimension.
{
for lane_idx in 0..major_dim {
let range_start = major_offsets[lane_idx];
let range_end = major_offsets[lane_idx + 1];
// Test that major offsets are monotonically increasing
if range_start > range_end {
return Err(NonmonotonicOffsets);
}
let minor_indices = &minor_indices[range_start..range_end];
// We test for in-bounds, uniqueness and monotonicity at the same time
// to ensure that we only visit each minor index once
let mut iter = minor_indices.iter();
let mut prev = None;
while let Some(next) = iter.next().copied() {
if next >= minor_dim {
return Err(MinorIndexOutOfBounds);
}
if let Some(prev) = prev {
if prev > next {
return Err(NonmonotonicMinorIndices);
} else if prev == next {
return Err(DuplicateEntry);
}
}
prev = Some(next);
}
}
}
Ok(Self {
major_offsets,
minor_indices,
minor_dim,
})
}
/// An iterator over the explicitly stored "non-zero" entries (i, j).
///
/// The iteration happens in a lane-major fashion, meaning that the lane index i
/// increases monotonically, and the minor index j increases monotonically within each
/// lane i.
///
/// Examples
/// --------
///
/// ```
/// # use nalgebra_sparse::pattern::SparsityPattern;
/// let offsets = vec![0, 2, 3, 4];
/// let minor_indices = vec![0, 2, 1, 0];
/// let pattern = SparsityPattern::try_from_offsets_and_indices(3, 4, offsets, minor_indices)
/// .unwrap();
///
/// let entries: Vec<_> = pattern.entries().collect();
/// assert_eq!(entries, vec![(0, 0), (0, 2), (1, 1), (2, 0)]);
/// ```
///
pub fn entries(&self) -> SparsityPatternIter {
SparsityPatternIter::from_pattern(self)
}
/// Returns the raw offset and index data for the sparsity pattern.
///
/// Examples
/// --------
///
/// ```
/// # use nalgebra_sparse::pattern::SparsityPattern;
/// let offsets = vec![0, 2, 3, 4];
/// let minor_indices = vec![0, 2, 1, 0];
/// let pattern = SparsityPattern::try_from_offsets_and_indices(
/// 3,
/// 4,
/// offsets.clone(),
/// minor_indices.clone())
/// .unwrap();
/// let (offsets2, minor_indices2) = pattern.disassemble();
/// assert_eq!(offsets2, offsets);
/// assert_eq!(minor_indices2, minor_indices);
/// ```
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>) {
(self.major_offsets, self.minor_indices)
}
/// Computes the transpose of the sparsity pattern.
///
/// This is analogous to matrix transposition, i.e. an entry `(i, j)` becomes `(j, i)` in the
/// new pattern.
pub fn transpose(&self) -> Self {
// By using unit () values, we can use the same routines as for CSR/CSC matrices
let values = vec![(); self.nnz()];
let (new_offsets, new_indices, _) = transpose_cs(
self.major_dim(),
self.minor_dim(),
self.major_offsets(),
self.minor_indices(),
&values,
);
// TODO: Skip checks
Self::try_from_offsets_and_indices(
self.minor_dim(),
self.major_dim(),
new_offsets,
new_indices,
)
.expect("Internal error: Transpose should never fail.")
}
}
/// Error type for `SparsityPattern` format errors.
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq)]
pub enum SparsityPatternFormatError {
/// Indicates an invalid number of offsets.
///
/// The number of offsets must be equal to (major_dim + 1).
InvalidOffsetArrayLength,
/// Indicates that the first or last entry in the offset array did not conform to
/// specifications.
///
/// The first entry must be 0, and the last entry must be exactly one greater than the
/// major dimension.
InvalidOffsetFirstLast,
/// Indicates that the major offsets are not monotonically increasing.
NonmonotonicOffsets,
/// One or more minor indices are out of bounds.
MinorIndexOutOfBounds,
/// One or more duplicate entries were detected.
///
/// Two entries are considered duplicates if they are part of the same major lane and have
/// the same minor index.
DuplicateEntry,
/// Indicates that minor indices are not monotonically increasing within each lane.
NonmonotonicMinorIndices,
}
impl From<SparsityPatternFormatError> for SparseFormatError {
fn from(err: SparsityPatternFormatError) -> Self {
use crate::SparseFormatErrorKind;
use crate::SparseFormatErrorKind::*;
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
use SparsityPatternFormatError::*;
match err {
InvalidOffsetArrayLength
| InvalidOffsetFirstLast
| NonmonotonicOffsets
| NonmonotonicMinorIndices => {
SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err))
}
MinorIndexOutOfBounds => {
SparseFormatError::from_kind_and_error(IndexOutOfBounds, Box::from(err))
}
PatternDuplicateEntry => SparseFormatError::from_kind_and_error(
#[allow(unused_qualifications)]
SparseFormatErrorKind::DuplicateEntry,
Box::from(err),
),
}
}
}
impl fmt::Display for SparsityPatternFormatError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SparsityPatternFormatError::InvalidOffsetArrayLength => {
write!(f, "Length of offset array is not equal to (major_dim + 1).")
}
SparsityPatternFormatError::InvalidOffsetFirstLast => {
write!(f, "First or last offset is incompatible with format.")
}
SparsityPatternFormatError::NonmonotonicOffsets => {
write!(f, "Offsets are not monotonically increasing.")
}
SparsityPatternFormatError::MinorIndexOutOfBounds => {
write!(f, "A minor index is out of bounds.")
}
SparsityPatternFormatError::DuplicateEntry => {
write!(f, "Input data contains duplicate entries.")
}
SparsityPatternFormatError::NonmonotonicMinorIndices => {
write!(
f,
"Minor indices are not monotonically increasing within each lane."
)
}
}
}
}
impl Error for SparsityPatternFormatError {}
/// Iterator type for iterating over entries in a sparsity pattern.
#[derive(Debug, Clone)]
pub struct SparsityPatternIter<'a> {
// See implementation of Iterator::next for an explanation of how these members are used
major_offsets: &'a [usize],
minor_indices: &'a [usize],
current_lane_idx: usize,
remaining_minors_in_lane: &'a [usize],
}
impl<'a> SparsityPatternIter<'a> {
fn from_pattern(pattern: &'a SparsityPattern) -> Self {
let first_lane_end = pattern.major_offsets().get(1).unwrap_or(&0);
let minors_in_first_lane = &pattern.minor_indices()[0..*first_lane_end];
Self {
major_offsets: pattern.major_offsets(),
minor_indices: pattern.minor_indices(),
current_lane_idx: 0,
remaining_minors_in_lane: minors_in_first_lane,
}
}
}
impl<'a> Iterator for SparsityPatternIter<'a> {
type Item = (usize, usize);
#[inline]
fn next(&mut self) -> Option<Self::Item> {
// We ensure fast iteration across each lane by iteratively "draining" a slice
// corresponding to the remaining column indices in the particular lane.
// When we reach the end of this slice, we are at the end of a lane,
// and we must do some bookkeeping for preparing the iteration of the next lane
// (or stop iteration if we're through all lanes).
// This way we can avoid doing unnecessary bookkeeping on every iteration,
// instead paying a small price whenever we jump to a new lane.
if let Some(minor_idx) = self.remaining_minors_in_lane.first() {
let item = Some((self.current_lane_idx, *minor_idx));
self.remaining_minors_in_lane = &self.remaining_minors_in_lane[1..];
item
} else {
loop {
// Keep skipping lanes until we found a non-empty lane or there are no more lanes
if self.current_lane_idx + 2 >= self.major_offsets.len() {
// We've processed all lanes, so we're at the end of the iterator
// (note: keep in mind that offsets.len() == major_dim() + 1, hence we need +2)
return None;
} else {
// Bump lane index and check if the lane is non-empty
self.current_lane_idx += 1;
let lower = self.major_offsets[self.current_lane_idx];
let upper = self.major_offsets[self.current_lane_idx + 1];
if upper > lower {
self.remaining_minors_in_lane = &self.minor_indices[(lower + 1)..upper];
return Some((self.current_lane_idx, self.minor_indices[lower]));
}
}
}
}
}
}

View File

@ -0,0 +1,374 @@
//! Functionality for integrating `nalgebra-sparse` with `proptest`.
//!
//! **This module is only available if the `proptest-support` feature is enabled**.
//!
//! The strategies provided here are generally expected to be able to generate the entire range
//! of possible outputs given the constraints on dimensions and values. However, there are no
//! particular guarantees on the distribution of possible values.
// Contains some patched code from proptest that we can remove in the (hopefully near) future.
// See docs in file for more details.
mod proptest_patched;
use crate::coo::CooMatrix;
use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use crate::pattern::SparsityPattern;
use nalgebra::proptest::DimRange;
use nalgebra::{Dim, Scalar};
use proptest::collection::{btree_set, hash_map, vec};
use proptest::prelude::*;
use proptest::sample::Index;
use std::cmp::min;
use std::iter::repeat;
fn dense_row_major_coord_strategy(
nrows: usize,
ncols: usize,
nnz: usize,
) -> impl Strategy<Value = Vec<(usize, usize)>> {
assert!(nnz <= nrows * ncols);
let mut booleans = vec![true; nnz];
booleans.append(&mut vec![false; (nrows * ncols) - nnz]);
// Make sure that exactly `nnz` of the booleans are true
// TODO: We cannot use the below code because of a bug in proptest, see
// https://github.com/AltSysrq/proptest/pull/217
// so for now we're using a patched version of the Shuffle adapter
// (see also docs in `proptest_patched`
// Just(booleans)
// // Need to shuffle to make sure they are randomly distributed
// .prop_shuffle()
proptest_patched::Shuffle(Just(booleans)).prop_map(move |booleans| {
booleans
.into_iter()
.enumerate()
.filter_map(|(index, is_entry)| {
if is_entry {
// Convert linear index to row/col pair
let i = index / ncols;
let j = index % ncols;
Some((i, j))
} else {
None
}
})
.collect::<Vec<_>>()
})
}
/// A strategy for generating `nnz` triplets.
///
/// This strategy should generally only be used when `nnz` is close to `nrows * ncols`.
fn dense_triplet_strategy<T>(
value_strategy: T,
nrows: usize,
ncols: usize,
nnz: usize,
) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
{
assert!(nnz <= nrows * ncols);
// Construct a number of booleans of which exactly `nnz` are true.
let booleans: Vec<_> = repeat(true)
.take(nnz)
.chain(repeat(false))
.take(nrows * ncols)
.collect();
Just(booleans)
// Shuffle the booleans so that they are randomly distributed
.prop_shuffle()
// Convert the booleans into a list of coordinate pairs
.prop_map(move |booleans| {
booleans
.into_iter()
.enumerate()
.filter_map(|(index, is_entry)| {
if is_entry {
// Convert linear index to row/col pair
let i = index / ncols;
let j = index % ncols;
Some((i, j))
} else {
None
}
})
.collect::<Vec<_>>()
})
// Assign values to each coordinate pair in order to generate a list of triplets
.prop_flat_map(move |coords| {
vec![value_strategy.clone(); coords.len()].prop_map(move |values| {
coords
.clone()
.into_iter()
.zip(values)
.map(|((i, j), v)| (i, j, v))
.collect::<Vec<_>>()
})
})
}
/// A strategy for generating `nnz` triplets.
///
/// This strategy should generally only be used when `nnz << nrows * ncols`. If `nnz` is too
/// close to `nrows * ncols` it may fail due to excessive rejected samples.
fn sparse_triplet_strategy<T>(
value_strategy: T,
nrows: usize,
ncols: usize,
nnz: usize,
) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
{
// Have to handle the zero case: proptest doesn't like empty ranges (i.e. 0 .. 0)
let row_index_strategy = if nrows > 0 { 0..nrows } else { 0..1 };
let col_index_strategy = if ncols > 0 { 0..ncols } else { 0..1 };
let coord_strategy = (row_index_strategy, col_index_strategy);
hash_map(coord_strategy, value_strategy.clone(), nnz)
.prop_map(|hash_map| {
let triplets: Vec<_> = hash_map.into_iter().map(|((i, j), v)| (i, j, v)).collect();
triplets
})
// Although order in the hash map is unspecified, it's not necessarily *random*
// - or, in particular, it does not necessarily sample the whole space of possible outcomes -
// so we additionally shuffle the triplets
.prop_shuffle()
}
/// A strategy for producing COO matrices without duplicate entries.
///
/// The values of the matrix are picked from the provided `value_strategy`, while the size of the
/// generated matrices is determined by the ranges `rows` and `cols`. The number of explicitly
/// stored entries is bounded from above by `max_nonzeros`. Note that the matrix might still
/// contain explicitly stored zeroes if the value strategy is capable of generating zero values.
pub fn coo_no_duplicates<T>(
value_strategy: T,
rows: impl Into<DimRange>,
cols: impl Into<DimRange>,
max_nonzeros: usize,
) -> impl Strategy<Value = CooMatrix<T::Value>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
{
(
rows.into().to_range_inclusive(),
cols.into().to_range_inclusive(),
)
.prop_flat_map(move |(nrows, ncols)| {
let max_nonzeros = min(max_nonzeros, nrows * ncols);
let size_range = 0..=max_nonzeros;
let value_strategy = value_strategy.clone();
size_range
.prop_flat_map(move |nnz| {
let value_strategy = value_strategy.clone();
if nnz as f64 > 0.10 * (nrows as f64) * (ncols as f64) {
// If the number of nnz is sufficiently dense, then use the dense
// sample strategy
dense_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
} else {
// Otherwise, use a hash map strategy so that we can get a sparse sampling
// (so that complexity is rather on the order of max_nnz than nrows * ncols)
sparse_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
}
})
.prop_map(move |triplets| {
let mut coo = CooMatrix::new(nrows, ncols);
for (i, j, v) in triplets {
coo.push(i, j, v);
}
coo
})
})
}
/// A strategy for producing COO matrices with duplicate entries.
///
/// The values of the matrix are picked from the provided `value_strategy`, while the size of the
/// generated matrices is determined by the ranges `rows` and `cols`. Note that the values
/// only apply to individual entries, and since this strategy can generate duplicate entries,
/// the matrix will generally have values outside the range determined by `value_strategy` when
/// converted to other formats, since the duplicate entries are summed together in this case.
///
/// The number of explicitly stored entries is bounded from above by `max_nonzeros`. The maximum
/// number of duplicate entries is determined by `max_duplicates`. Note that the matrix might still
/// contain explicitly stored zeroes if the value strategy is capable of generating zero values.
pub fn coo_with_duplicates<T>(
value_strategy: T,
rows: impl Into<DimRange>,
cols: impl Into<DimRange>,
max_nonzeros: usize,
max_duplicates: usize,
) -> impl Strategy<Value = CooMatrix<T::Value>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
{
let coo_strategy = coo_no_duplicates(value_strategy.clone(), rows, cols, max_nonzeros);
let duplicate_strategy = vec((any::<Index>(), value_strategy.clone()), 0..=max_duplicates);
(coo_strategy, duplicate_strategy)
.prop_flat_map(|(coo, duplicates)| {
let mut triplets: Vec<(usize, usize, T::Value)> = coo
.triplet_iter()
.map(|(i, j, v)| (i, j, v.clone()))
.collect();
if !triplets.is_empty() {
let duplicates_iter: Vec<_> = duplicates
.into_iter()
.map(|(idx, val)| {
let (i, j, _) = idx.get(&triplets);
(*i, *j, val)
})
.collect();
triplets.extend(duplicates_iter);
}
// Make sure to shuffle so that the duplicates get mixed in with the non-duplicates
let shuffled = Just(triplets).prop_shuffle();
(Just(coo.nrows()), Just(coo.ncols()), shuffled)
})
.prop_map(move |(nrows, ncols, triplets)| {
let mut coo = CooMatrix::new(nrows, ncols);
for (i, j, v) in triplets {
coo.push(i, j, v);
}
coo
})
}
fn sparsity_pattern_from_row_major_coords<I>(
nmajor: usize,
nminor: usize,
coords: I,
) -> SparsityPattern
where
I: Iterator<Item = (usize, usize)> + ExactSizeIterator,
{
let mut minors = Vec::with_capacity(coords.len());
let mut offsets = Vec::with_capacity(nmajor + 1);
let mut current_major = 0;
offsets.push(0);
for (idx, (i, j)) in coords.enumerate() {
assert!(i >= current_major);
assert!(
i < nmajor && j < nminor,
"Generated coords are out of bounds"
);
while current_major < i {
offsets.push(idx);
current_major += 1;
}
minors.push(j);
}
while current_major < nmajor {
offsets.push(minors.len());
current_major += 1;
}
assert_eq!(offsets.first().unwrap(), &0);
assert_eq!(offsets.len(), nmajor + 1);
SparsityPattern::try_from_offsets_and_indices(nmajor, nminor, offsets, minors)
.expect("Internal error: Generated sparsity pattern is invalid")
}
/// A strategy for generating sparsity patterns.
pub fn sparsity_pattern(
major_lanes: impl Into<DimRange>,
minor_lanes: impl Into<DimRange>,
max_nonzeros: usize,
) -> impl Strategy<Value = SparsityPattern> {
(
major_lanes.into().to_range_inclusive(),
minor_lanes.into().to_range_inclusive(),
)
.prop_flat_map(move |(nmajor, nminor)| {
let max_nonzeros = min(nmajor * nminor, max_nonzeros);
(Just(nmajor), Just(nminor), 0..=max_nonzeros)
})
.prop_flat_map(move |(nmajor, nminor, nnz)| {
if 10 * nnz < nmajor * nminor {
// If nnz is small compared to a dense matrix, then use a sparse sampling strategy
btree_set((0..nmajor, 0..nminor), nnz)
.prop_map(move |coords| {
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords.into_iter())
})
.boxed()
} else {
// If the required number of nonzeros is sufficiently dense,
// we instead use a dense sampling
dense_row_major_coord_strategy(nmajor, nminor, nnz)
.prop_map(move |coords| {
let coords = coords.into_iter();
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords)
})
.boxed()
}
})
}
/// A strategy for generating CSR matrices.
pub fn csr<T>(
value_strategy: T,
rows: impl Into<DimRange>,
cols: impl Into<DimRange>,
max_nonzeros: usize,
) -> impl Strategy<Value = CsrMatrix<T::Value>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
{
let rows = rows.into();
let cols = cols.into();
sparsity_pattern(
rows.lower_bound().value()..=rows.upper_bound().value(),
cols.lower_bound().value()..=cols.upper_bound().value(),
max_nonzeros,
)
.prop_flat_map(move |pattern| {
let nnz = pattern.nnz();
let values = vec![value_strategy.clone(); nnz];
(Just(pattern), values)
})
.prop_map(|(pattern, values)| {
CsrMatrix::try_from_pattern_and_values(pattern, values)
.expect("Internal error: Generated CsrMatrix is invalid")
})
}
/// A strategy for generating CSC matrices.
pub fn csc<T>(
value_strategy: T,
rows: impl Into<DimRange>,
cols: impl Into<DimRange>,
max_nonzeros: usize,
) -> impl Strategy<Value = CscMatrix<T::Value>>
where
T: Strategy + Clone + 'static,
T::Value: Scalar,
{
let rows = rows.into();
let cols = cols.into();
sparsity_pattern(
cols.lower_bound().value()..=cols.upper_bound().value(),
rows.lower_bound().value()..=rows.upper_bound().value(),
max_nonzeros,
)
.prop_flat_map(move |pattern| {
let nnz = pattern.nnz();
let values = vec![value_strategy.clone(); nnz];
(Just(pattern), values)
})
.prop_map(|(pattern, values)| {
CscMatrix::try_from_pattern_and_values(pattern, values)
.expect("Internal error: Generated CscMatrix is invalid")
})
}

View File

@ -0,0 +1,146 @@
//! Contains a modified implementation of `proptest::strategy::Shuffle`.
//!
//! The current implementation in `proptest` does not generate all permutations, which is
//! problematic for our proptest generators. The issue has been fixed in
//! https://github.com/AltSysrq/proptest/pull/217
//! but it has yet to be merged and released. As soon as this fix makes it into a new release,
//! the modified code here can be removed.
//!
/*!
This code has been copied and adapted from
https://github.com/AltSysrq/proptest/blob/master/proptest/src/strategy/shuffle.rs
The original licensing text is:
//-
// Copyright 2017 Jason Lingle
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
*/
use proptest::num;
use proptest::prelude::Rng;
use proptest::strategy::{NewTree, Shuffleable, Strategy, ValueTree};
use proptest::test_runner::{TestRng, TestRunner};
use std::cell::Cell;
#[derive(Clone, Debug)]
#[must_use = "strategies do nothing unless used"]
pub struct Shuffle<S>(pub(super) S);
impl<S: Strategy> Strategy for Shuffle<S>
where
S::Value: Shuffleable,
{
type Tree = ShuffleValueTree<S::Tree>;
type Value = S::Value;
fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
let rng = runner.new_rng();
self.0.new_tree(runner).map(|inner| ShuffleValueTree {
inner,
rng,
dist: Cell::new(None),
simplifying_inner: false,
})
}
}
#[derive(Clone, Debug)]
pub struct ShuffleValueTree<V> {
inner: V,
rng: TestRng,
dist: Cell<Option<num::usize::BinarySearch>>,
simplifying_inner: bool,
}
impl<V: ValueTree> ShuffleValueTree<V>
where
V::Value: Shuffleable,
{
fn init_dist(&self, dflt: usize) -> usize {
if self.dist.get().is_none() {
self.dist.set(Some(num::usize::BinarySearch::new(dflt)));
}
self.dist.get().unwrap().current()
}
fn force_init_dist(&self) {
if self.dist.get().is_none() {
let _ = self.init_dist(self.current().shuffle_len());
}
}
}
impl<V: ValueTree> ValueTree for ShuffleValueTree<V>
where
V::Value: Shuffleable,
{
type Value = V::Value;
fn current(&self) -> V::Value {
let mut value = self.inner.current();
let len = value.shuffle_len();
// The maximum distance to swap elements. This could be larger than
// `value` if `value` has reduced size during shrinking; that's OK,
// since we only use this to filter swaps.
let max_swap = self.init_dist(len);
// If empty collection or all swaps will be filtered out, there's
// nothing to shuffle.
if 0 == len || 0 == max_swap {
return value;
}
let mut rng = self.rng.clone();
for start_index in 0..len - 1 {
// Determine the other index to be swapped, then skip the swap if
// it is too far. This ordering is critical, as it ensures that we
// generate the same sequence of random numbers every time.
// NOTE: The below line is the whole reason for the existence of this adapted code
// We need to be able to swap with the same element, so that some elements remain in
// place rather being swapped
// let end_index = rng.gen_range(start_index + 1..len);
let end_index = rng.gen_range(start_index..len);
if end_index - start_index <= max_swap {
value.shuffle_swap(start_index, end_index);
}
}
value
}
fn simplify(&mut self) -> bool {
if self.simplifying_inner {
self.inner.simplify()
} else {
// Ensure that we've initialised `dist` to *something* to give
// consistent non-panicking behaviour even if called in an
// unexpected sequence.
self.force_init_dist();
if self.dist.get_mut().as_mut().unwrap().simplify() {
true
} else {
self.simplifying_inner = true;
self.inner.simplify()
}
}
}
fn complicate(&mut self) -> bool {
if self.simplifying_inner {
self.inner.complicate()
} else {
self.force_init_dist();
self.dist.get_mut().as_mut().unwrap().complicate()
}
}
}

View File

@ -0,0 +1,77 @@
use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{csc, csr};
use proptest::strategy::Strategy;
use std::convert::TryFrom;
use std::fmt::Debug;
use std::ops::RangeInclusive;
#[macro_export]
macro_rules! assert_panics {
($e:expr) => {{
use std::panic::catch_unwind;
use std::stringify;
let expr_string = stringify!($e);
// Note: We cannot manipulate the panic hook here, because it is global and the test
// suite is run in parallel, which leads to race conditions in the sense
// that some regular tests that panic might not output anything anymore.
// Unfortunately this means that output is still printed to stdout if
// we run cargo test -- --nocapture. But Cargo does not forward this if the test
// binary is not run with nocapture, so it is somewhat acceptable nonetheless.
let result = catch_unwind(|| $e);
if result.is_ok() {
panic!(
"assert_panics!({}) failed: the expression did not panic.",
expr_string
);
}
}};
}
pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
pub const PROPTEST_MAX_NNZ: usize = 40;
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5..=5;
pub fn value_strategy<T>() -> RangeInclusive<T>
where
T: TryFrom<i32>,
T::Error: Debug,
{
let (start, end) = (
PROPTEST_I32_VALUE_STRATEGY.start(),
PROPTEST_I32_VALUE_STRATEGY.end(),
);
T::try_from(*start).unwrap()..=T::try_from(*end).unwrap()
}
pub fn non_zero_i32_value_strategy() -> impl Strategy<Value = i32> {
let (start, end) = (
PROPTEST_I32_VALUE_STRATEGY.start(),
PROPTEST_I32_VALUE_STRATEGY.end(),
);
assert!(start < &0);
assert!(end > &0);
// Note: we don't use RangeInclusive for the second range, because then we'd have different
// types, which would require boxing
(*start..0).prop_union(1..*end + 1)
}
pub fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
csr(
PROPTEST_I32_VALUE_STRATEGY,
PROPTEST_MATRIX_DIM,
PROPTEST_MATRIX_DIM,
PROPTEST_MAX_NNZ,
)
}
pub fn csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> {
csc(
PROPTEST_I32_VALUE_STRATEGY,
PROPTEST_MATRIX_DIM,
PROPTEST_MATRIX_DIM,
PROPTEST_MAX_NNZ,
)
}

View File

@ -0,0 +1,8 @@
//! Unit tests
#[cfg(any(not(feature = "proptest-support"), not(feature = "compare")))]
compile_error!("Tests must be run with features `proptest-support` and `compare`");
mod unit_tests;
#[macro_use]
pub mod common;

View File

@ -0,0 +1,8 @@
# Seeds for failure cases proptest has generated in the past. It is
# automatically read and these particular cases re-run before any
# novel cases are generated.
#
# It is recommended to check this file in to source control so that
# everyone who runs the test benefits from these saved cases.
cc 3f71c8edc555965e521e3aaf58c736240a0e333c3a9d54e8a836d7768c371215 # shrinks to matrix = CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0], minor_indices: [], minor_dim: 0 }, values: [] } }
cc aef645e3184b814ef39fbb10234f12e6ff502ab515dabefafeedab5895e22b12 # shrinks to (matrix, rhs) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 4, 7, 11, 14], minor_indices: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 0, 2, 3], minor_dim: 4 }, values: [1.0, 0.0, 0.0, 0.0, 0.0, 40.90124126326177, 36.975170911665906, 0.0, 36.975170911665906, 42.51062858727923, -12.984115201530539, 0.0, -12.984115201530539, 27.73953543265418] } }, Matrix { data: VecStorage { data: [0.0, 0.0, 0.0, -4.05763092330143], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 1 } } })

View File

@ -0,0 +1,117 @@
#![cfg_attr(rustfmt, rustfmt_skip)]
use crate::common::{value_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ};
use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::factorization::{CscCholesky};
use nalgebra_sparse::proptest::csc;
use nalgebra::{Matrix5, Vector5, Cholesky, DMatrix};
use nalgebra::proptest::matrix;
use proptest::prelude::*;
use matrixcompare::{assert_matrix_eq, prop_assert_matrix_eq};
fn positive_definite() -> impl Strategy<Value=CscMatrix<f64>> {
let csc_f64 = csc(value_strategy::<f64>(),
PROPTEST_MATRIX_DIM,
PROPTEST_MATRIX_DIM,
PROPTEST_MAX_NNZ);
csc_f64
.prop_map(|x| {
// Add a small multiple of the identity to ensure positive definiteness
x.transpose() * &x + CscMatrix::identity(x.ncols())
})
}
proptest! {
#[test]
fn cholesky_correct_for_positive_definite_matrices(
matrix in positive_definite()
) {
let cholesky = CscCholesky::factor(&matrix).unwrap();
let l = cholesky.take_l();
let matrix_reconstructed = &l * l.transpose();
prop_assert_matrix_eq!(matrix_reconstructed, matrix, comp = abs, tol = 1e-8);
let is_lower_triangular = l.triplet_iter().all(|(i, j, _)| j <= i);
prop_assert!(is_lower_triangular);
}
#[test]
fn cholesky_solve_positive_definite(
(matrix, rhs) in positive_definite()
.prop_flat_map(|csc| {
let rhs = matrix(value_strategy::<f64>(), csc.nrows(), PROPTEST_MATRIX_DIM);
(Just(csc), rhs)
})
) {
let cholesky = CscCholesky::factor(&matrix).unwrap();
// solve_mut
{
let mut x = rhs.clone();
cholesky.solve_mut(&mut x);
prop_assert_matrix_eq!(&matrix * &x, rhs, comp=abs, tol=1e-12);
}
// solve
{
let x = cholesky.solve(&rhs);
prop_assert_matrix_eq!(&matrix * &x, rhs, comp=abs, tol=1e-12);
}
}
}
// This is a test ported from nalgebra's "sparse" module, for the original CsCholesky impl
#[test]
fn cs_cholesky() {
let mut a = Matrix5::new(
40.0, 0.0, 0.0, 0.0, 0.0,
2.0, 60.0, 0.0, 0.0, 0.0,
1.0, 0.0, 11.0, 0.0, 0.0,
0.0, 0.0, 0.0, 50.0, 0.0,
1.0, 0.0, 0.0, 4.0, 10.0
);
a.fill_upper_triangle_with_lower_triangle();
test_cholesky(a);
let a = Matrix5::from_diagonal(&Vector5::new(40.0, 60.0, 11.0, 50.0, 10.0));
test_cholesky(a);
let mut a = Matrix5::new(
40.0, 0.0, 0.0, 0.0, 0.0,
2.0, 60.0, 0.0, 0.0, 0.0,
1.0, 0.0, 11.0, 0.0, 0.0,
1.0, 0.0, 0.0, 50.0, 0.0,
0.0, 0.0, 0.0, 4.0, 10.0
);
a.fill_upper_triangle_with_lower_triangle();
test_cholesky(a);
let mut a = Matrix5::new(
2.0, 0.0, 0.0, 0.0, 0.0,
0.0, 2.0, 0.0, 0.0, 0.0,
1.0, 1.0, 2.0, 0.0, 0.0,
0.0, 0.0, 0.0, 2.0, 0.0,
1.0, 1.0, 0.0, 0.0, 2.0
);
a.fill_upper_triangle_with_lower_triangle();
// Test crate::new, left_looking, and up_looking implementations.
test_cholesky(a);
}
fn test_cholesky(a: Matrix5<f64>) {
// TODO: Test "refactor"
let cs_a = CscMatrix::from(&a);
let chol_a = Cholesky::new(a).unwrap();
let chol_cs_a = CscCholesky::factor(&cs_a).unwrap();
let l = chol_a.l();
let cs_l = chol_cs_a.take_l();
let l = DMatrix::from_iterator(l.nrows(), l.ncols(), l.iter().cloned());
let cs_l_mat = DMatrix::from(&cs_l);
assert_matrix_eq!(l, cs_l_mat, comp = abs, tol = 1e-12);
}

View File

@ -0,0 +1,10 @@
# Seeds for failure cases proptest has generated in the past. It is
# automatically read and these particular cases re-run before any
# novel cases are generated.
#
# It is recommended to check this file in to source control so that
# everyone who runs the test benefits from these saved cases.
cc 07cb95127d2700ff2000157938e351ce2b43f3e6419d69b00726abfc03e682bd # shrinks to coo = CooMatrix { nrows: 4, ncols: 5, row_indices: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0], col_indices: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 4, 3], values: [1, -5, -4, -5, 1, 2, 4, -4, -4, -5, 2, -2, 4, -4] }
cc 8fdaf70d6091d89a6617573547745e9802bb9c1ce7c6ec7ad4f301cd05d54c5d # shrinks to dense = Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 5 } } }
cc 6961760ac7915b57a28230524cea7e9bfcea4f31790e3c0569ea74af904c2d79 # shrinks to coo = CooMatrix { nrows: 6, ncols: 6, row_indices: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0], col_indices: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0], values: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0] }
cc c9a1af218f7a974f1fda7b8909c2635d735eedbfe953082ef6b0b92702bf6d1b # shrinks to dense = Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 5 } } }

View File

@ -0,0 +1,452 @@
use crate::common::csc_strategy;
use nalgebra::proptest::matrix;
use nalgebra::DMatrix;
use nalgebra_sparse::convert::serial::{
convert_coo_csc, convert_coo_csr, convert_coo_dense, convert_csc_coo, convert_csc_csr,
convert_csc_dense, convert_csr_coo, convert_csr_csc, convert_csr_dense, convert_dense_coo,
convert_dense_csc, convert_dense_csr,
};
use nalgebra_sparse::coo::CooMatrix;
use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::proptest::{coo_no_duplicates, coo_with_duplicates, csc, csr};
use proptest::prelude::*;
#[test]
fn test_convert_dense_coo() {
// No duplicates
{
#[rustfmt::skip]
let entries = &[1, 0, 3,
0, 5, 0];
// The COO representation of a dense matrix is not unique.
// Here we implicitly test that the coo matrix is indeed constructed from column-major
// iteration of the dense matrix.
let dense = DMatrix::from_row_slice(2, 3, entries);
let coo = CooMatrix::try_from_triplets(2, 3, vec![0, 1, 0], vec![0, 1, 2], vec![1, 5, 3])
.unwrap();
assert_eq!(CooMatrix::from(&dense), coo);
assert_eq!(DMatrix::from(&coo), dense);
}
// Duplicates
// No duplicates
{
#[rustfmt::skip]
let entries = &[1, 0, 3,
0, 5, 0];
// The COO representation of a dense matrix is not unique.
// Here we implicitly test that the coo matrix is indeed constructed from column-major
// iteration of the dense matrix.
let dense = DMatrix::from_row_slice(2, 3, entries);
let coo_no_dup =
CooMatrix::try_from_triplets(2, 3, vec![0, 1, 0], vec![0, 1, 2], vec![1, 5, 3])
.unwrap();
let coo_dup = CooMatrix::try_from_triplets(
2,
3,
vec![0, 1, 0, 1],
vec![0, 1, 2, 1],
vec![1, -2, 3, 7],
)
.unwrap();
assert_eq!(CooMatrix::from(&dense), coo_no_dup);
assert_eq!(DMatrix::from(&coo_dup), dense);
}
}
#[test]
fn test_convert_coo_csr() {
// No duplicates
{
let coo = {
let mut coo = CooMatrix::new(3, 4);
coo.push(1, 3, 4);
coo.push(0, 1, 2);
coo.push(2, 0, 1);
coo.push(2, 3, 2);
coo.push(2, 2, 1);
coo
};
let expected_csr = CsrMatrix::try_from_csr_data(
3,
4,
vec![0, 1, 2, 5],
vec![1, 3, 0, 2, 3],
vec![2, 4, 1, 1, 2],
)
.unwrap();
assert_eq!(convert_coo_csr(&coo), expected_csr);
}
// Duplicates
{
let coo = {
let mut coo = CooMatrix::new(3, 4);
coo.push(1, 3, 4);
coo.push(2, 3, 2);
coo.push(0, 1, 2);
coo.push(2, 0, 1);
coo.push(2, 3, 2);
coo.push(0, 1, 3);
coo.push(2, 2, 1);
coo
};
let expected_csr = CsrMatrix::try_from_csr_data(
3,
4,
vec![0, 1, 2, 5],
vec![1, 3, 0, 2, 3],
vec![5, 4, 1, 1, 4],
)
.unwrap();
assert_eq!(convert_coo_csr(&coo), expected_csr);
}
}
#[test]
fn test_convert_csr_coo() {
let csr = CsrMatrix::try_from_csr_data(
3,
4,
vec![0, 1, 2, 5],
vec![1, 3, 0, 2, 3],
vec![5, 4, 1, 1, 4],
)
.unwrap();
let expected_coo = CooMatrix::try_from_triplets(
3,
4,
vec![0, 1, 2, 2, 2],
vec![1, 3, 0, 2, 3],
vec![5, 4, 1, 1, 4],
)
.unwrap();
assert_eq!(convert_csr_coo(&csr), expected_coo);
}
#[test]
fn test_convert_coo_csc() {
// No duplicates
{
let coo = {
let mut coo = CooMatrix::new(3, 4);
coo.push(1, 3, 4);
coo.push(0, 1, 2);
coo.push(2, 0, 1);
coo.push(2, 3, 2);
coo.push(2, 2, 1);
coo
};
let expected_csc = CscMatrix::try_from_csc_data(
3,
4,
vec![0, 1, 2, 3, 5],
vec![2, 0, 2, 1, 2],
vec![1, 2, 1, 4, 2],
)
.unwrap();
assert_eq!(convert_coo_csc(&coo), expected_csc);
}
// Duplicates
{
let coo = {
let mut coo = CooMatrix::new(3, 4);
coo.push(1, 3, 4);
coo.push(2, 3, 2);
coo.push(0, 1, 2);
coo.push(2, 0, 1);
coo.push(2, 3, 2);
coo.push(0, 1, 3);
coo.push(2, 2, 1);
coo
};
let expected_csc = CscMatrix::try_from_csc_data(
3,
4,
vec![0, 1, 2, 3, 5],
vec![2, 0, 2, 1, 2],
vec![1, 5, 1, 4, 4],
)
.unwrap();
assert_eq!(convert_coo_csc(&coo), expected_csc);
}
}
#[test]
fn test_convert_csc_coo() {
let csc = CscMatrix::try_from_csc_data(
3,
4,
vec![0, 1, 2, 3, 5],
vec![2, 0, 2, 1, 2],
vec![1, 2, 1, 4, 2],
)
.unwrap();
let expected_coo = CooMatrix::try_from_triplets(
3,
4,
vec![2, 0, 2, 1, 2],
vec![0, 1, 2, 3, 3],
vec![1, 2, 1, 4, 2],
)
.unwrap();
assert_eq!(convert_csc_coo(&csc), expected_coo);
}
#[test]
fn test_convert_csr_csc_bidirectional() {
let csr = CsrMatrix::try_from_csr_data(
3,
4,
vec![0, 3, 4, 6],
vec![1, 2, 3, 0, 1, 3],
vec![5, 3, 2, 2, 1, 4],
)
.unwrap();
let csc = CscMatrix::try_from_csc_data(
3,
4,
vec![0, 1, 3, 4, 6],
vec![1, 0, 2, 0, 0, 2],
vec![2, 5, 1, 3, 2, 4],
)
.unwrap();
assert_eq!(convert_csr_csc(&csr), csc);
assert_eq!(convert_csc_csr(&csc), csr);
}
#[test]
fn test_convert_csr_dense_bidirectional() {
let csr = CsrMatrix::try_from_csr_data(
3,
4,
vec![0, 3, 4, 6],
vec![1, 2, 3, 0, 1, 3],
vec![5, 3, 2, 2, 1, 4],
)
.unwrap();
#[rustfmt::skip]
let dense = DMatrix::from_row_slice(3, 4, &[
0, 5, 3, 2,
2, 0, 0, 0,
0, 1, 0, 4
]);
assert_eq!(convert_csr_dense(&csr), dense);
assert_eq!(convert_dense_csr(&dense), csr);
}
#[test]
fn test_convert_csc_dense_bidirectional() {
let csc = CscMatrix::try_from_csc_data(
3,
4,
vec![0, 1, 3, 4, 6],
vec![1, 0, 2, 0, 0, 2],
vec![2, 5, 1, 3, 2, 4],
)
.unwrap();
#[rustfmt::skip]
let dense = DMatrix::from_row_slice(3, 4, &[
0, 5, 3, 2,
2, 0, 0, 0,
0, 1, 0, 4
]);
assert_eq!(convert_csc_dense(&csc), dense);
assert_eq!(convert_dense_csc(&dense), csc);
}
fn coo_strategy() -> impl Strategy<Value = CooMatrix<i32>> {
coo_with_duplicates(-5..=5, 0..=6usize, 0..=6usize, 40, 2)
}
fn coo_no_duplicates_strategy() -> impl Strategy<Value = CooMatrix<i32>> {
coo_no_duplicates(-5..=5, 0..=6usize, 0..=6usize, 40)
}
fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
csr(-5..=5, 0..=6usize, 0..=6usize, 40)
}
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
fn non_zero_csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
csr(1..=5, 0..=6usize, 0..=6usize, 40)
}
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
fn non_zero_csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> {
csc(1..=5, 0..=6usize, 0..=6usize, 40)
}
fn dense_strategy() -> impl Strategy<Value = DMatrix<i32>> {
matrix(-5..=5, 0..=6, 0..=6)
}
proptest! {
#[test]
fn convert_dense_coo_roundtrip(dense in matrix(-5 ..= 5, 0 ..=6, 0..=6)) {
let coo = convert_dense_coo(&dense);
let dense2 = convert_coo_dense(&coo);
prop_assert_eq!(&dense, &dense2);
}
#[test]
fn convert_coo_dense_coo_roundtrip(coo in coo_strategy()) {
// We cannot compare the result of the roundtrip coo -> dense -> coo directly for
// two reasons:
// 1. the COO matrices will generally have different ordering of elements
// 2. explicitly stored zero entries in the original matrix will be discarded
// when converting back to COO
// Therefore we instead compare the results of converting the COO matrix
// at the end of the roundtrip with its dense representation
let dense = convert_coo_dense(&coo);
let coo2 = convert_dense_coo(&dense);
let dense2 = convert_coo_dense(&coo2);
prop_assert_eq!(dense, dense2);
}
#[test]
fn coo_from_dense_roundtrip(dense in dense_strategy()) {
prop_assert_eq!(&dense, &DMatrix::from(&CooMatrix::from(&dense)));
}
#[test]
fn convert_coo_csr_agrees_with_csr_dense(coo in coo_strategy()) {
let coo_dense = convert_coo_dense(&coo);
let csr = convert_coo_csr(&coo);
let csr_dense = convert_csr_dense(&csr);
prop_assert_eq!(csr_dense, coo_dense);
// It might be that COO matrices have a higher nnz due to duplicates,
// so we can only check that the CSR matrix has no more than the original COO matrix
prop_assert!(csr.nnz() <= coo.nnz());
}
#[test]
fn convert_coo_csr_nnz(coo in coo_no_duplicates_strategy()) {
// Check that the NNZ are equal when converting from a CooMatrix without
// duplicates to a CSR matrix
let csr = convert_coo_csr(&coo);
prop_assert_eq!(csr.nnz(), coo.nnz());
}
#[test]
fn convert_csr_coo_roundtrip(csr in csr_strategy()) {
let coo = convert_csr_coo(&csr);
let csr2 = convert_coo_csr(&coo);
prop_assert_eq!(csr2, csr);
}
#[test]
fn coo_from_csr_roundtrip(csr in csr_strategy()) {
prop_assert_eq!(&csr, &CsrMatrix::from(&CooMatrix::from(&csr)));
}
#[test]
fn csr_from_dense_roundtrip(dense in dense_strategy()) {
prop_assert_eq!(&dense, &DMatrix::from(&CsrMatrix::from(&dense)));
}
#[test]
fn convert_csr_dense_roundtrip(csr in non_zero_csr_strategy()) {
// Since we only generate CSR matrices with non-zero values, we know that the
// number of explicitly stored entries when converting CSR->Dense->CSR should be
// unchanged, so that we can verify that the result is the same as the input
let dense = convert_csr_dense(&csr);
let csr2 = convert_dense_csr(&dense);
prop_assert_eq!(csr2, csr);
}
#[test]
fn convert_csc_coo_roundtrip(csc in csc_strategy()) {
let coo = convert_csc_coo(&csc);
let csc2 = convert_coo_csc(&coo);
prop_assert_eq!(csc2, csc);
}
#[test]
fn coo_from_csc_roundtrip(csc in csc_strategy()) {
prop_assert_eq!(&csc, &CscMatrix::from(&CooMatrix::from(&csc)));
}
#[test]
fn convert_csc_dense_roundtrip(csc in non_zero_csc_strategy()) {
// Since we only generate CSC matrices with non-zero values, we know that the
// number of explicitly stored entries when converting CSC->Dense->CSC should be
// unchanged, so that we can verify that the result is the same as the input
let dense = convert_csc_dense(&csc);
let csc2 = convert_dense_csc(&dense);
prop_assert_eq!(csc2, csc);
}
#[test]
fn csc_from_dense_roundtrip(dense in dense_strategy()) {
prop_assert_eq!(&dense, &DMatrix::from(&CscMatrix::from(&dense)));
}
#[test]
fn convert_coo_csc_agrees_with_csc_dense(coo in coo_strategy()) {
let coo_dense = convert_coo_dense(&coo);
let csc = convert_coo_csc(&coo);
let csc_dense = convert_csc_dense(&csc);
prop_assert_eq!(csc_dense, coo_dense);
// It might be that COO matrices have a higher nnz due to duplicates,
// so we can only check that the CSR matrix has no more than the original COO matrix
prop_assert!(csc.nnz() <= coo.nnz());
}
#[test]
fn convert_coo_csc_nnz(coo in coo_no_duplicates_strategy()) {
// Check that the NNZ are equal when converting from a CooMatrix without
// duplicates to a CSR matrix
let csc = convert_coo_csc(&coo);
prop_assert_eq!(csc.nnz(), coo.nnz());
}
#[test]
fn convert_csc_csr_roundtrip(csc in csc_strategy()) {
let csr = convert_csc_csr(&csc);
let csc2 = convert_csr_csc(&csr);
prop_assert_eq!(csc2, csc);
}
#[test]
fn convert_csr_csc_roundtrip(csr in csr_strategy()) {
let csc = convert_csr_csc(&csr);
let csr2 = convert_csc_csr(&csc);
prop_assert_eq!(csr2, csr);
}
#[test]
fn csc_from_csr_roundtrip(csr in csr_strategy()) {
prop_assert_eq!(&csr, &CsrMatrix::from(&CscMatrix::from(&csr)));
}
#[test]
fn csr_from_csc_roundtrip(csc in csc_strategy()) {
prop_assert_eq!(&csc, &CscMatrix::from(&CsrMatrix::from(&csc)));
}
}

View File

@ -0,0 +1,254 @@
use crate::assert_panics;
use nalgebra::DMatrix;
use nalgebra_sparse::coo::CooMatrix;
use nalgebra_sparse::SparseFormatErrorKind;
#[test]
fn coo_construction_for_valid_data() {
// Test that construction with try_from_triplets succeeds, that the state of the
// matrix afterwards is as expected, and that the dense representation matches expectations.
{
// Zero matrix
let coo =
CooMatrix::<i32>::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new()).unwrap();
assert_eq!(coo.nrows(), 3);
assert_eq!(coo.ncols(), 2);
assert!(coo.triplet_iter().next().is_none());
assert!(coo.row_indices().is_empty());
assert!(coo.col_indices().is_empty());
assert!(coo.values().is_empty());
assert_eq!(DMatrix::from(&coo), DMatrix::repeat(3, 2, 0));
}
{
// Arbitrary matrix, no duplicates
let i = vec![0, 1, 0, 0, 2];
let j = vec![0, 2, 1, 3, 3];
let v = vec![2, 3, 7, 3, 1];
let coo =
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
assert_eq!(coo.nrows(), 3);
assert_eq!(coo.ncols(), 5);
assert_eq!(i.as_slice(), coo.row_indices());
assert_eq!(j.as_slice(), coo.col_indices());
assert_eq!(v.as_slice(), coo.values());
let expected_triplets: Vec<_> = i
.iter()
.zip(&j)
.zip(&v)
.map(|((i, j), v)| (*i, *j, *v))
.collect();
let actual_triplets: Vec<_> = coo.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
assert_eq!(actual_triplets, expected_triplets);
#[rustfmt::skip]
let expected_dense = DMatrix::from_row_slice(3, 5, &[
2, 7, 0, 3, 0,
0, 0, 3, 0, 0,
0, 0, 0, 1, 0
]);
assert_eq!(DMatrix::from(&coo), expected_dense);
}
{
// Arbitrary matrix, with duplicates
let i = vec![0, 1, 0, 0, 0, 0, 2, 1];
let j = vec![0, 2, 0, 1, 0, 3, 3, 2];
let v = vec![2, 3, 4, 7, 1, 3, 1, 5];
let coo =
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
assert_eq!(coo.nrows(), 3);
assert_eq!(coo.ncols(), 5);
assert_eq!(i.as_slice(), coo.row_indices());
assert_eq!(j.as_slice(), coo.col_indices());
assert_eq!(v.as_slice(), coo.values());
let expected_triplets: Vec<_> = i
.iter()
.zip(&j)
.zip(&v)
.map(|((i, j), v)| (*i, *j, *v))
.collect();
let actual_triplets: Vec<_> = coo.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
assert_eq!(actual_triplets, expected_triplets);
#[rustfmt::skip]
let expected_dense = DMatrix::from_row_slice(3, 5, &[
7, 7, 0, 3, 0,
0, 0, 8, 0, 0,
0, 0, 0, 1, 0
]);
assert_eq!(DMatrix::from(&coo), expected_dense);
}
}
#[test]
fn coo_try_from_triplets_reports_out_of_bounds_indices() {
{
// 0x0 matrix
let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]);
assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
}
{
// 1x1 matrix, row out of bounds
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]);
assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
}
{
// 1x1 matrix, col out of bounds
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]);
assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
}
{
// 1x1 matrix, row and col out of bounds
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]);
assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
}
{
// Arbitrary matrix, row out of bounds
let i = vec![0, 1, 0, 3, 2];
let j = vec![0, 2, 1, 3, 3];
let v = vec![2, 3, 7, 3, 1];
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
}
{
// Arbitrary matrix, col out of bounds
let i = vec![0, 1, 0, 0, 2];
let j = vec![0, 2, 1, 5, 3];
let v = vec![2, 3, 7, 3, 1];
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
assert!(matches!(
result.unwrap_err().kind(),
SparseFormatErrorKind::IndexOutOfBounds
));
}
}
#[test]
fn coo_try_from_triplets_panics_on_mismatched_vectors() {
// Check that try_from_triplets panics when the triplet vectors have different lengths
macro_rules! assert_errs {
($result:expr) => {
assert!(matches!(
$result.unwrap_err().kind(),
SparseFormatErrorKind::InvalidStructure
))
};
}
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1, 2],
vec![0],
vec![0]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1],
vec![0, 0],
vec![0]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1],
vec![0],
vec![0, 1]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1, 2],
vec![0, 1],
vec![0]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1],
vec![0, 1],
vec![0, 1]
));
assert_errs!(CooMatrix::<i32>::try_from_triplets(
3,
5,
vec![1, 1],
vec![0],
vec![0, 1]
));
}
#[test]
fn coo_push_valid_entries() {
let mut coo = CooMatrix::new(3, 3);
coo.push(0, 0, 1);
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1)]);
coo.push(0, 0, 2);
assert_eq!(
coo.triplet_iter().collect::<Vec<_>>(),
vec![(0, 0, &1), (0, 0, &2)]
);
coo.push(2, 2, 3);
assert_eq!(
coo.triplet_iter().collect::<Vec<_>>(),
vec![(0, 0, &1), (0, 0, &2), (2, 2, &3)]
);
}
#[test]
fn coo_push_out_of_bounds_entries() {
{
// 0x0 matrix
let coo = CooMatrix::new(0, 0);
assert_panics!(coo.clone().push(0, 0, 1));
}
{
// 0x1 matrix
assert_panics!(CooMatrix::new(0, 1).push(0, 0, 1));
}
{
// 1x0 matrix
assert_panics!(CooMatrix::new(1, 0).push(0, 0, 1));
}
{
// Arbitrary matrix dimensions
let coo = CooMatrix::new(3, 2);
assert_panics!(coo.clone().push(3, 0, 1));
assert_panics!(coo.clone().push(2, 2, 1));
assert_panics!(coo.clone().push(3, 2, 1));
}
}

View File

@ -0,0 +1,7 @@
# Seeds for failure cases proptest has generated in the past. It is
# automatically read and these particular cases re-run before any
# novel cases are generated.
#
# It is recommended to check this file in to source control so that
# everyone who runs the test benefits from these saved cases.
cc a71b4654827840ed539b82cd7083615b0fb3f75933de6a7d91d8148a2bf34960 # shrinks to (csc, triplet_subset) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 1, 1, 1, 1, 1, 1], minor_indices: [0], minor_dim: 4 }, values: [0] } }, {})

View File

@ -0,0 +1,605 @@
use nalgebra::DMatrix;
use nalgebra_sparse::csc::CscMatrix;
use nalgebra_sparse::{SparseEntry, SparseEntryMut, SparseFormatErrorKind};
use proptest::prelude::*;
use proptest::sample::subsequence;
use crate::assert_panics;
use crate::common::csc_strategy;
use std::collections::HashSet;
#[test]
fn csc_matrix_valid_data() {
// Construct matrix from valid data and check that selected methods return results
// that agree with expectations.
{
// A CSC matrix with zero explicitly stored entries
let offsets = vec![0, 0, 0, 0];
let indices = vec![];
let values = Vec::<i32>::new();
let mut matrix = CscMatrix::try_from_csc_data(2, 3, offsets, indices, values).unwrap();
assert_eq!(matrix, CscMatrix::zeros(2, 3));
assert_eq!(matrix.nrows(), 2);
assert_eq!(matrix.ncols(), 3);
assert_eq!(matrix.nnz(), 0);
assert_eq!(matrix.col_offsets(), &[0, 0, 0, 0]);
assert_eq!(matrix.row_indices(), &[]);
assert_eq!(matrix.values(), &[]);
assert!(matrix.triplet_iter().next().is_none());
assert!(matrix.triplet_iter_mut().next().is_none());
assert_eq!(matrix.col(0).nrows(), 2);
assert_eq!(matrix.col(0).nnz(), 0);
assert_eq!(matrix.col(0).row_indices(), &[]);
assert_eq!(matrix.col(0).values(), &[]);
assert_eq!(matrix.col_mut(0).nrows(), 2);
assert_eq!(matrix.col_mut(0).nnz(), 0);
assert_eq!(matrix.col_mut(0).row_indices(), &[]);
assert_eq!(matrix.col_mut(0).values(), &[]);
assert_eq!(matrix.col_mut(0).values_mut(), &[]);
assert_eq!(
matrix.col_mut(0).rows_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.col(1).nrows(), 2);
assert_eq!(matrix.col(1).nnz(), 0);
assert_eq!(matrix.col(1).row_indices(), &[]);
assert_eq!(matrix.col(1).values(), &[]);
assert_eq!(matrix.col_mut(1).nrows(), 2);
assert_eq!(matrix.col_mut(1).nnz(), 0);
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
assert_eq!(matrix.col_mut(1).values(), &[]);
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
assert_eq!(
matrix.col_mut(1).rows_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.col(2).nrows(), 2);
assert_eq!(matrix.col(2).nnz(), 0);
assert_eq!(matrix.col(2).row_indices(), &[]);
assert_eq!(matrix.col(2).values(), &[]);
assert_eq!(matrix.col_mut(2).nrows(), 2);
assert_eq!(matrix.col_mut(2).nnz(), 0);
assert_eq!(matrix.col_mut(2).row_indices(), &[]);
assert_eq!(matrix.col_mut(2).values(), &[]);
assert_eq!(matrix.col_mut(2).values_mut(), &[]);
assert_eq!(
matrix.col_mut(2).rows_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert!(matrix.get_col(3).is_none());
assert!(matrix.get_col_mut(3).is_none());
let (offsets, indices, values) = matrix.disassemble();
assert_eq!(offsets, vec![0, 0, 0, 0]);
assert_eq!(indices, vec![]);
assert_eq!(values, vec![]);
}
{
// An arbitrary CSC matrix
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let mut matrix =
CscMatrix::try_from_csc_data(6, 3, offsets.clone(), indices.clone(), values.clone())
.unwrap();
assert_eq!(matrix.nrows(), 6);
assert_eq!(matrix.ncols(), 3);
assert_eq!(matrix.nnz(), 5);
assert_eq!(matrix.col_offsets(), &[0, 2, 2, 5]);
assert_eq!(matrix.row_indices(), &[0, 5, 1, 2, 3]);
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
let expected_triplets = vec![(0, 0, 0), (5, 0, 1), (1, 2, 2), (2, 2, 3), (3, 2, 4)];
assert_eq!(
matrix
.triplet_iter()
.map(|(i, j, v)| (i, j, *v))
.collect::<Vec<_>>(),
expected_triplets
);
assert_eq!(
matrix
.triplet_iter_mut()
.map(|(i, j, v)| (i, j, *v))
.collect::<Vec<_>>(),
expected_triplets
);
assert_eq!(matrix.col(0).nrows(), 6);
assert_eq!(matrix.col(0).nnz(), 2);
assert_eq!(matrix.col(0).row_indices(), &[0, 5]);
assert_eq!(matrix.col(0).values(), &[0, 1]);
assert_eq!(matrix.col_mut(0).nrows(), 6);
assert_eq!(matrix.col_mut(0).nnz(), 2);
assert_eq!(matrix.col_mut(0).row_indices(), &[0, 5]);
assert_eq!(matrix.col_mut(0).values(), &[0, 1]);
assert_eq!(matrix.col_mut(0).values_mut(), &[0, 1]);
assert_eq!(
matrix.col_mut(0).rows_and_values_mut(),
([0, 5].as_ref(), [0, 1].as_mut())
);
assert_eq!(matrix.col(1).nrows(), 6);
assert_eq!(matrix.col(1).nnz(), 0);
assert_eq!(matrix.col(1).row_indices(), &[]);
assert_eq!(matrix.col(1).values(), &[]);
assert_eq!(matrix.col_mut(1).nrows(), 6);
assert_eq!(matrix.col_mut(1).nnz(), 0);
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
assert_eq!(matrix.col_mut(1).values(), &[]);
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
assert_eq!(
matrix.col_mut(1).rows_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.col(2).nrows(), 6);
assert_eq!(matrix.col(2).nnz(), 3);
assert_eq!(matrix.col(2).row_indices(), &[1, 2, 3]);
assert_eq!(matrix.col(2).values(), &[2, 3, 4]);
assert_eq!(matrix.col_mut(2).nrows(), 6);
assert_eq!(matrix.col_mut(2).nnz(), 3);
assert_eq!(matrix.col_mut(2).row_indices(), &[1, 2, 3]);
assert_eq!(matrix.col_mut(2).values(), &[2, 3, 4]);
assert_eq!(matrix.col_mut(2).values_mut(), &[2, 3, 4]);
assert_eq!(
matrix.col_mut(2).rows_and_values_mut(),
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
);
assert!(matrix.get_col(3).is_none());
assert!(matrix.get_col_mut(3).is_none());
let (offsets2, indices2, values2) = matrix.disassemble();
assert_eq!(offsets2, offsets);
assert_eq!(indices2, indices);
assert_eq!(values2, values);
}
}
#[test]
fn csc_matrix_try_from_invalid_csc_data() {
{
// Empty offset array (invalid length)
let matrix = CscMatrix::try_from_csc_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Offset array invalid length for arbitrary data
let offsets = vec![0, 3, 5];
let indices = vec![0, 1, 2, 3, 5];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Invalid first entry in offsets array
let offsets = vec![1, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Invalid last entry in offsets array
let offsets = vec![0, 2, 2, 4];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Invalid length of offsets array
let offsets = vec![0, 2, 2];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Nonmonotonic offsets
let offsets = vec![0, 3, 2, 5];
let indices = vec![0, 1, 2, 3, 4];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Nonmonotonic minor indices
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 2, 3, 1, 4];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Minor index out of bounds
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 6, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::IndexOutOfBounds
);
}
{
// Duplicate entry
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 2, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::DuplicateEntry
);
}
}
#[test]
fn csc_disassemble_avoids_clone_when_owned() {
// Test that disassemble avoids cloning the sparsity pattern when it holds the sole reference
// to the pattern. We do so by checking that the pointer to the data is unchanged.
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let offsets_ptr = offsets.as_ptr();
let indices_ptr = indices.as_ptr();
let values_ptr = values.as_ptr();
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values).unwrap();
let (offsets, indices, values) = matrix.disassemble();
assert_eq!(offsets.as_ptr(), offsets_ptr);
assert_eq!(indices.as_ptr(), indices_ptr);
assert_eq!(values.as_ptr(), values_ptr);
}
// Rustfmt makes this test much harder to read by expanding some of the one-liners to 4-liners,
// so for now we skip rustfmt...
#[rustfmt::skip]
#[test]
fn csc_matrix_get_index_entry() {
// Test .get_entry(_mut) and .index_entry(_mut) methods
#[rustfmt::skip]
let dense = DMatrix::from_row_slice(2, 3, &[
1, 0, 3,
0, 5, 6
]);
let csc = CscMatrix::from(&dense);
assert_eq!(csc.get_entry(0, 0), Some(SparseEntry::NonZero(&1)));
assert_eq!(csc.index_entry(0, 0), SparseEntry::NonZero(&1));
assert_eq!(csc.get_entry(0, 1), Some(SparseEntry::Zero));
assert_eq!(csc.index_entry(0, 1), SparseEntry::Zero);
assert_eq!(csc.get_entry(0, 2), Some(SparseEntry::NonZero(&3)));
assert_eq!(csc.index_entry(0, 2), SparseEntry::NonZero(&3));
assert_eq!(csc.get_entry(1, 0), Some(SparseEntry::Zero));
assert_eq!(csc.index_entry(1, 0), SparseEntry::Zero);
assert_eq!(csc.get_entry(1, 1), Some(SparseEntry::NonZero(&5)));
assert_eq!(csc.index_entry(1, 1), SparseEntry::NonZero(&5));
assert_eq!(csc.get_entry(1, 2), Some(SparseEntry::NonZero(&6)));
assert_eq!(csc.index_entry(1, 2), SparseEntry::NonZero(&6));
// Check some out of bounds with .get_entry
assert_eq!(csc.get_entry(0, 3), None);
assert_eq!(csc.get_entry(0, 4), None);
assert_eq!(csc.get_entry(1, 3), None);
assert_eq!(csc.get_entry(1, 4), None);
assert_eq!(csc.get_entry(2, 0), None);
assert_eq!(csc.get_entry(2, 1), None);
assert_eq!(csc.get_entry(2, 2), None);
assert_eq!(csc.get_entry(2, 3), None);
assert_eq!(csc.get_entry(2, 4), None);
// Check that out of bounds with .index_entry panics
assert_panics!(csc.index_entry(0, 3));
assert_panics!(csc.index_entry(0, 4));
assert_panics!(csc.index_entry(1, 3));
assert_panics!(csc.index_entry(1, 4));
assert_panics!(csc.index_entry(2, 0));
assert_panics!(csc.index_entry(2, 1));
assert_panics!(csc.index_entry(2, 2));
assert_panics!(csc.index_entry(2, 3));
assert_panics!(csc.index_entry(2, 4));
{
// Check mutable versions of the above functions
let mut csc = csc;
assert_eq!(csc.get_entry_mut(0, 0), Some(SparseEntryMut::NonZero(&mut 1)));
assert_eq!(csc.index_entry_mut(0, 0), SparseEntryMut::NonZero(&mut 1));
assert_eq!(csc.get_entry_mut(0, 1), Some(SparseEntryMut::Zero));
assert_eq!(csc.index_entry_mut(0, 1), SparseEntryMut::Zero);
assert_eq!(csc.get_entry_mut(0, 2), Some(SparseEntryMut::NonZero(&mut 3)));
assert_eq!(csc.index_entry_mut(0, 2), SparseEntryMut::NonZero(&mut 3));
assert_eq!(csc.get_entry_mut(1, 0), Some(SparseEntryMut::Zero));
assert_eq!(csc.index_entry_mut(1, 0), SparseEntryMut::Zero);
assert_eq!(csc.get_entry_mut(1, 1), Some(SparseEntryMut::NonZero(&mut 5)));
assert_eq!(csc.index_entry_mut(1, 1), SparseEntryMut::NonZero(&mut 5));
assert_eq!(csc.get_entry_mut(1, 2), Some(SparseEntryMut::NonZero(&mut 6)));
assert_eq!(csc.index_entry_mut(1, 2), SparseEntryMut::NonZero(&mut 6));
// Check some out of bounds with .get_entry_mut
assert_eq!(csc.get_entry_mut(0, 3), None);
assert_eq!(csc.get_entry_mut(0, 4), None);
assert_eq!(csc.get_entry_mut(1, 3), None);
assert_eq!(csc.get_entry_mut(1, 4), None);
assert_eq!(csc.get_entry_mut(2, 0), None);
assert_eq!(csc.get_entry_mut(2, 1), None);
assert_eq!(csc.get_entry_mut(2, 2), None);
assert_eq!(csc.get_entry_mut(2, 3), None);
assert_eq!(csc.get_entry_mut(2, 4), None);
// Check that out of bounds with .index_entry_mut panics
// Note: the cloning is necessary because a mutable reference is not UnwindSafe
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(0, 3); });
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(0, 4); });
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(1, 3); });
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(1, 4); });
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 0); });
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 1); });
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 2); });
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 3); });
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 4); });
}
}
#[test]
fn csc_matrix_col_iter() {
// Note: this is the transpose of the matrix used for the similar csr_matrix_row_iter test
// (this way the actual tests are almost identical, due to the transposed relationship
// between CSR and CSC)
#[rustfmt::skip]
let dense = DMatrix::from_row_slice(4, 3, &[
0, 3, 0,
1, 0, 4,
2, 0, 0,
0, 0, 5,
]);
let csc = CscMatrix::from(&dense);
// Immutable iterator
{
let mut col_iter = csc.col_iter();
{
let col = col_iter.next().unwrap();
assert_eq!(col.nrows(), 4);
assert_eq!(col.nnz(), 2);
assert_eq!(col.row_indices(), &[1, 2]);
assert_eq!(col.values(), &[1, 2]);
assert_eq!(col.get_entry(0), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(1), Some(SparseEntry::NonZero(&1)));
assert_eq!(col.get_entry(2), Some(SparseEntry::NonZero(&2)));
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(4), None);
}
{
let col = col_iter.next().unwrap();
assert_eq!(col.nrows(), 4);
assert_eq!(col.nnz(), 1);
assert_eq!(col.row_indices(), &[0]);
assert_eq!(col.values(), &[3]);
assert_eq!(col.get_entry(0), Some(SparseEntry::NonZero(&3)));
assert_eq!(col.get_entry(1), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(4), None);
}
{
let col = col_iter.next().unwrap();
assert_eq!(col.nrows(), 4);
assert_eq!(col.nnz(), 2);
assert_eq!(col.row_indices(), &[1, 3]);
assert_eq!(col.values(), &[4, 5]);
assert_eq!(col.get_entry(0), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(1), Some(SparseEntry::NonZero(&4)));
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(3), Some(SparseEntry::NonZero(&5)));
assert_eq!(col.get_entry(4), None);
}
assert!(col_iter.next().is_none());
}
// Mutable iterator
{
let mut csc = csc;
let mut col_iter = csc.col_iter_mut();
{
let mut col = col_iter.next().unwrap();
assert_eq!(col.nrows(), 4);
assert_eq!(col.nnz(), 2);
assert_eq!(col.row_indices(), &[1, 2]);
assert_eq!(col.values(), &[1, 2]);
assert_eq!(col.get_entry(0), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(1), Some(SparseEntry::NonZero(&1)));
assert_eq!(col.get_entry(2), Some(SparseEntry::NonZero(&2)));
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(4), None);
assert_eq!(col.values_mut(), &mut [1, 2]);
assert_eq!(
col.rows_and_values_mut(),
([1, 2].as_ref(), [1, 2].as_mut())
);
assert_eq!(col.get_entry_mut(0), Some(SparseEntryMut::Zero));
assert_eq!(col.get_entry_mut(1), Some(SparseEntryMut::NonZero(&mut 1)));
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
assert_eq!(col.get_entry_mut(4), None);
}
{
let mut col = col_iter.next().unwrap();
assert_eq!(col.nrows(), 4);
assert_eq!(col.nnz(), 1);
assert_eq!(col.row_indices(), &[0]);
assert_eq!(col.values(), &[3]);
assert_eq!(col.get_entry(0), Some(SparseEntry::NonZero(&3)));
assert_eq!(col.get_entry(1), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(4), None);
assert_eq!(col.values_mut(), &mut [3]);
assert_eq!(col.rows_and_values_mut(), ([0].as_ref(), [3].as_mut()));
assert_eq!(col.get_entry_mut(0), Some(SparseEntryMut::NonZero(&mut 3)));
assert_eq!(col.get_entry_mut(1), Some(SparseEntryMut::Zero));
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
assert_eq!(col.get_entry_mut(4), None);
}
{
let mut col = col_iter.next().unwrap();
assert_eq!(col.nrows(), 4);
assert_eq!(col.nnz(), 2);
assert_eq!(col.row_indices(), &[1, 3]);
assert_eq!(col.values(), &[4, 5]);
assert_eq!(col.get_entry(0), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(1), Some(SparseEntry::NonZero(&4)));
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
assert_eq!(col.get_entry(3), Some(SparseEntry::NonZero(&5)));
assert_eq!(col.get_entry(4), None);
assert_eq!(col.values_mut(), &mut [4, 5]);
assert_eq!(
col.rows_and_values_mut(),
([1, 3].as_ref(), [4, 5].as_mut())
);
assert_eq!(col.get_entry_mut(0), Some(SparseEntryMut::Zero));
assert_eq!(col.get_entry_mut(1), Some(SparseEntryMut::NonZero(&mut 4)));
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
assert_eq!(col.get_entry_mut(4), None);
}
assert!(col_iter.next().is_none());
}
}
proptest! {
#[test]
fn csc_double_transpose_is_identity(csc in csc_strategy()) {
prop_assert_eq!(csc.transpose().transpose(), csc);
}
#[test]
fn csc_transpose_agrees_with_dense(csc in csc_strategy()) {
let dense_transpose = DMatrix::from(&csc).transpose();
let csc_transpose = csc.transpose();
prop_assert_eq!(dense_transpose, DMatrix::from(&csc_transpose));
prop_assert_eq!(csc.nnz(), csc_transpose.nnz());
}
#[test]
fn csc_filter(
(csc, triplet_subset)
in csc_strategy()
.prop_flat_map(|matrix| {
let triplets: Vec<_> = matrix.triplet_iter().cloned_values().collect();
let subset = subsequence(triplets, 0 ..= matrix.nnz())
.prop_map(|triplet_subset| {
let set: HashSet<_> = triplet_subset.into_iter().collect();
set
});
(Just(matrix), subset)
}))
{
// We generate a CscMatrix and a HashSet corresponding to a subset of the (i, j, v)
// values in the matrix, which we use for filtering the matrix entries.
// The resulting triplets in the filtered matrix must then be exactly equal to
// the subset.
let filtered = csc.filter(|i, j, v| triplet_subset.contains(&(i, j, *v)));
let filtered_triplets: HashSet<_> = filtered
.triplet_iter()
.cloned_values()
.collect();
prop_assert_eq!(filtered_triplets, triplet_subset);
}
#[test]
fn csc_lower_triangle_agrees_with_dense(csc in csc_strategy()) {
let csc_lower_triangle = csc.lower_triangle();
prop_assert_eq!(DMatrix::from(&csc_lower_triangle), DMatrix::from(&csc).lower_triangle());
prop_assert!(csc_lower_triangle.nnz() <= csc.nnz());
}
#[test]
fn csc_upper_triangle_agrees_with_dense(csc in csc_strategy()) {
let csc_upper_triangle = csc.upper_triangle();
prop_assert_eq!(DMatrix::from(&csc_upper_triangle), DMatrix::from(&csc).upper_triangle());
prop_assert!(csc_upper_triangle.nnz() <= csc.nnz());
}
#[test]
fn csc_diagonal_as_csc(csc in csc_strategy()) {
let d = csc.diagonal_as_csc();
let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect();
let csc_diagonal_entries: HashSet<_> = csc
.triplet_iter()
.cloned_values()
.filter(|&(i, j, _)| i == j)
.collect();
prop_assert_eq!(d_entries, csc_diagonal_entries);
}
#[test]
fn csc_identity(n in 0 ..= 6usize) {
let csc = CscMatrix::<i32>::identity(n);
prop_assert_eq!(csc.nnz(), n);
prop_assert_eq!(DMatrix::from(&csc), DMatrix::identity(n, n));
}
}

View File

@ -0,0 +1,601 @@
use nalgebra::DMatrix;
use nalgebra_sparse::csr::CsrMatrix;
use nalgebra_sparse::{SparseEntry, SparseEntryMut, SparseFormatErrorKind};
use proptest::prelude::*;
use proptest::sample::subsequence;
use crate::assert_panics;
use crate::common::csr_strategy;
use std::collections::HashSet;
#[test]
fn csr_matrix_valid_data() {
// Construct matrix from valid data and check that selected methods return results
// that agree with expectations.
{
// A CSR matrix with zero explicitly stored entries
let offsets = vec![0, 0, 0, 0];
let indices = vec![];
let values = Vec::<i32>::new();
let mut matrix = CsrMatrix::try_from_csr_data(3, 2, offsets, indices, values).unwrap();
assert_eq!(matrix, CsrMatrix::zeros(3, 2));
assert_eq!(matrix.nrows(), 3);
assert_eq!(matrix.ncols(), 2);
assert_eq!(matrix.nnz(), 0);
assert_eq!(matrix.row_offsets(), &[0, 0, 0, 0]);
assert_eq!(matrix.col_indices(), &[]);
assert_eq!(matrix.values(), &[]);
assert!(matrix.triplet_iter().next().is_none());
assert!(matrix.triplet_iter_mut().next().is_none());
assert_eq!(matrix.row(0).ncols(), 2);
assert_eq!(matrix.row(0).nnz(), 0);
assert_eq!(matrix.row(0).col_indices(), &[]);
assert_eq!(matrix.row(0).values(), &[]);
assert_eq!(matrix.row_mut(0).ncols(), 2);
assert_eq!(matrix.row_mut(0).nnz(), 0);
assert_eq!(matrix.row_mut(0).col_indices(), &[]);
assert_eq!(matrix.row_mut(0).values(), &[]);
assert_eq!(matrix.row_mut(0).values_mut(), &[]);
assert_eq!(
matrix.row_mut(0).cols_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.row(1).ncols(), 2);
assert_eq!(matrix.row(1).nnz(), 0);
assert_eq!(matrix.row(1).col_indices(), &[]);
assert_eq!(matrix.row(1).values(), &[]);
assert_eq!(matrix.row_mut(1).ncols(), 2);
assert_eq!(matrix.row_mut(1).nnz(), 0);
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
assert_eq!(matrix.row_mut(1).values(), &[]);
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
assert_eq!(
matrix.row_mut(1).cols_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.row(2).ncols(), 2);
assert_eq!(matrix.row(2).nnz(), 0);
assert_eq!(matrix.row(2).col_indices(), &[]);
assert_eq!(matrix.row(2).values(), &[]);
assert_eq!(matrix.row_mut(2).ncols(), 2);
assert_eq!(matrix.row_mut(2).nnz(), 0);
assert_eq!(matrix.row_mut(2).col_indices(), &[]);
assert_eq!(matrix.row_mut(2).values(), &[]);
assert_eq!(matrix.row_mut(2).values_mut(), &[]);
assert_eq!(
matrix.row_mut(2).cols_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert!(matrix.get_row(3).is_none());
assert!(matrix.get_row_mut(3).is_none());
let (offsets, indices, values) = matrix.disassemble();
assert_eq!(offsets, vec![0, 0, 0, 0]);
assert_eq!(indices, vec![]);
assert_eq!(values, vec![]);
}
{
// An arbitrary CSR matrix
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let mut matrix =
CsrMatrix::try_from_csr_data(3, 6, offsets.clone(), indices.clone(), values.clone())
.unwrap();
assert_eq!(matrix.nrows(), 3);
assert_eq!(matrix.ncols(), 6);
assert_eq!(matrix.nnz(), 5);
assert_eq!(matrix.row_offsets(), &[0, 2, 2, 5]);
assert_eq!(matrix.col_indices(), &[0, 5, 1, 2, 3]);
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
let expected_triplets = vec![(0, 0, 0), (0, 5, 1), (2, 1, 2), (2, 2, 3), (2, 3, 4)];
assert_eq!(
matrix
.triplet_iter()
.map(|(i, j, v)| (i, j, *v))
.collect::<Vec<_>>(),
expected_triplets
);
assert_eq!(
matrix
.triplet_iter_mut()
.map(|(i, j, v)| (i, j, *v))
.collect::<Vec<_>>(),
expected_triplets
);
assert_eq!(matrix.row(0).ncols(), 6);
assert_eq!(matrix.row(0).nnz(), 2);
assert_eq!(matrix.row(0).col_indices(), &[0, 5]);
assert_eq!(matrix.row(0).values(), &[0, 1]);
assert_eq!(matrix.row_mut(0).ncols(), 6);
assert_eq!(matrix.row_mut(0).nnz(), 2);
assert_eq!(matrix.row_mut(0).col_indices(), &[0, 5]);
assert_eq!(matrix.row_mut(0).values(), &[0, 1]);
assert_eq!(matrix.row_mut(0).values_mut(), &[0, 1]);
assert_eq!(
matrix.row_mut(0).cols_and_values_mut(),
([0, 5].as_ref(), [0, 1].as_mut())
);
assert_eq!(matrix.row(1).ncols(), 6);
assert_eq!(matrix.row(1).nnz(), 0);
assert_eq!(matrix.row(1).col_indices(), &[]);
assert_eq!(matrix.row(1).values(), &[]);
assert_eq!(matrix.row_mut(1).ncols(), 6);
assert_eq!(matrix.row_mut(1).nnz(), 0);
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
assert_eq!(matrix.row_mut(1).values(), &[]);
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
assert_eq!(
matrix.row_mut(1).cols_and_values_mut(),
([].as_ref(), [].as_mut())
);
assert_eq!(matrix.row(2).ncols(), 6);
assert_eq!(matrix.row(2).nnz(), 3);
assert_eq!(matrix.row(2).col_indices(), &[1, 2, 3]);
assert_eq!(matrix.row(2).values(), &[2, 3, 4]);
assert_eq!(matrix.row_mut(2).ncols(), 6);
assert_eq!(matrix.row_mut(2).nnz(), 3);
assert_eq!(matrix.row_mut(2).col_indices(), &[1, 2, 3]);
assert_eq!(matrix.row_mut(2).values(), &[2, 3, 4]);
assert_eq!(matrix.row_mut(2).values_mut(), &[2, 3, 4]);
assert_eq!(
matrix.row_mut(2).cols_and_values_mut(),
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
);
assert!(matrix.get_row(3).is_none());
assert!(matrix.get_row_mut(3).is_none());
let (offsets2, indices2, values2) = matrix.disassemble();
assert_eq!(offsets2, offsets);
assert_eq!(indices2, indices);
assert_eq!(values2, values);
}
}
#[test]
fn csr_matrix_try_from_invalid_csr_data() {
{
// Empty offset array (invalid length)
let matrix = CsrMatrix::try_from_csr_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Offset array invalid length for arbitrary data
let offsets = vec![0, 3, 5];
let indices = vec![0, 1, 2, 3, 5];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Invalid first entry in offsets array
let offsets = vec![1, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Invalid last entry in offsets array
let offsets = vec![0, 2, 2, 4];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Invalid length of offsets array
let offsets = vec![0, 2, 2];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Nonmonotonic offsets
let offsets = vec![0, 3, 2, 5];
let indices = vec![0, 1, 2, 3, 4];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Nonmonotonic minor indices
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 2, 3, 1, 4];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::InvalidStructure
);
}
{
// Minor index out of bounds
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 6, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::IndexOutOfBounds
);
}
{
// Duplicate entry
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 2, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
assert_eq!(
matrix.unwrap_err().kind(),
&SparseFormatErrorKind::DuplicateEntry
);
}
}
#[test]
fn csr_disassemble_avoids_clone_when_owned() {
// Test that disassemble avoids cloning the sparsity pattern when it holds the sole reference
// to the pattern. We do so by checking that the pointer to the data is unchanged.
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let values = vec![0, 1, 2, 3, 4];
let offsets_ptr = offsets.as_ptr();
let indices_ptr = indices.as_ptr();
let values_ptr = values.as_ptr();
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values).unwrap();
let (offsets, indices, values) = matrix.disassemble();
assert_eq!(offsets.as_ptr(), offsets_ptr);
assert_eq!(indices.as_ptr(), indices_ptr);
assert_eq!(values.as_ptr(), values_ptr);
}
// Rustfmt makes this test much harder to read by expanding some of the one-liners to 4-liners,
// so for now we skip rustfmt...
#[rustfmt::skip]
#[test]
fn csr_matrix_get_index_entry() {
// Test .get_entry(_mut) and .index_entry(_mut) methods
#[rustfmt::skip]
let dense = DMatrix::from_row_slice(2, 3, &[
1, 0, 3,
0, 5, 6
]);
let csr = CsrMatrix::from(&dense);
assert_eq!(csr.get_entry(0, 0), Some(SparseEntry::NonZero(&1)));
assert_eq!(csr.index_entry(0, 0), SparseEntry::NonZero(&1));
assert_eq!(csr.get_entry(0, 1), Some(SparseEntry::Zero));
assert_eq!(csr.index_entry(0, 1), SparseEntry::Zero);
assert_eq!(csr.get_entry(0, 2), Some(SparseEntry::NonZero(&3)));
assert_eq!(csr.index_entry(0, 2), SparseEntry::NonZero(&3));
assert_eq!(csr.get_entry(1, 0), Some(SparseEntry::Zero));
assert_eq!(csr.index_entry(1, 0), SparseEntry::Zero);
assert_eq!(csr.get_entry(1, 1), Some(SparseEntry::NonZero(&5)));
assert_eq!(csr.index_entry(1, 1), SparseEntry::NonZero(&5));
assert_eq!(csr.get_entry(1, 2), Some(SparseEntry::NonZero(&6)));
assert_eq!(csr.index_entry(1, 2), SparseEntry::NonZero(&6));
// Check some out of bounds with .get_entry
assert_eq!(csr.get_entry(0, 3), None);
assert_eq!(csr.get_entry(0, 4), None);
assert_eq!(csr.get_entry(1, 3), None);
assert_eq!(csr.get_entry(1, 4), None);
assert_eq!(csr.get_entry(2, 0), None);
assert_eq!(csr.get_entry(2, 1), None);
assert_eq!(csr.get_entry(2, 2), None);
assert_eq!(csr.get_entry(2, 3), None);
assert_eq!(csr.get_entry(2, 4), None);
// Check that out of bounds with .index_entry panics
assert_panics!(csr.index_entry(0, 3));
assert_panics!(csr.index_entry(0, 4));
assert_panics!(csr.index_entry(1, 3));
assert_panics!(csr.index_entry(1, 4));
assert_panics!(csr.index_entry(2, 0));
assert_panics!(csr.index_entry(2, 1));
assert_panics!(csr.index_entry(2, 2));
assert_panics!(csr.index_entry(2, 3));
assert_panics!(csr.index_entry(2, 4));
{
// Check mutable versions of the above functions
let mut csr = csr;
assert_eq!(csr.get_entry_mut(0, 0), Some(SparseEntryMut::NonZero(&mut 1)));
assert_eq!(csr.index_entry_mut(0, 0), SparseEntryMut::NonZero(&mut 1));
assert_eq!(csr.get_entry_mut(0, 1), Some(SparseEntryMut::Zero));
assert_eq!(csr.index_entry_mut(0, 1), SparseEntryMut::Zero);
assert_eq!(csr.get_entry_mut(0, 2), Some(SparseEntryMut::NonZero(&mut 3)));
assert_eq!(csr.index_entry_mut(0, 2), SparseEntryMut::NonZero(&mut 3));
assert_eq!(csr.get_entry_mut(1, 0), Some(SparseEntryMut::Zero));
assert_eq!(csr.index_entry_mut(1, 0), SparseEntryMut::Zero);
assert_eq!(csr.get_entry_mut(1, 1), Some(SparseEntryMut::NonZero(&mut 5)));
assert_eq!(csr.index_entry_mut(1, 1), SparseEntryMut::NonZero(&mut 5));
assert_eq!(csr.get_entry_mut(1, 2), Some(SparseEntryMut::NonZero(&mut 6)));
assert_eq!(csr.index_entry_mut(1, 2), SparseEntryMut::NonZero(&mut 6));
// Check some out of bounds with .get_entry_mut
assert_eq!(csr.get_entry_mut(0, 3), None);
assert_eq!(csr.get_entry_mut(0, 4), None);
assert_eq!(csr.get_entry_mut(1, 3), None);
assert_eq!(csr.get_entry_mut(1, 4), None);
assert_eq!(csr.get_entry_mut(2, 0), None);
assert_eq!(csr.get_entry_mut(2, 1), None);
assert_eq!(csr.get_entry_mut(2, 2), None);
assert_eq!(csr.get_entry_mut(2, 3), None);
assert_eq!(csr.get_entry_mut(2, 4), None);
// Check that out of bounds with .index_entry_mut panics
// Note: the cloning is necessary because a mutable reference is not UnwindSafe
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(0, 3); });
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(0, 4); });
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(1, 3); });
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(1, 4); });
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 0); });
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 1); });
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 2); });
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 3); });
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 4); });
}
}
#[test]
fn csr_matrix_row_iter() {
#[rustfmt::skip]
let dense = DMatrix::from_row_slice(3, 4, &[
0, 1, 2, 0,
3, 0, 0, 0,
0, 4, 0, 5
]);
let csr = CsrMatrix::from(&dense);
// Immutable iterator
{
let mut row_iter = csr.row_iter();
{
let row = row_iter.next().unwrap();
assert_eq!(row.ncols(), 4);
assert_eq!(row.nnz(), 2);
assert_eq!(row.col_indices(), &[1, 2]);
assert_eq!(row.values(), &[1, 2]);
assert_eq!(row.get_entry(0), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(1), Some(SparseEntry::NonZero(&1)));
assert_eq!(row.get_entry(2), Some(SparseEntry::NonZero(&2)));
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(4), None);
}
{
let row = row_iter.next().unwrap();
assert_eq!(row.ncols(), 4);
assert_eq!(row.nnz(), 1);
assert_eq!(row.col_indices(), &[0]);
assert_eq!(row.values(), &[3]);
assert_eq!(row.get_entry(0), Some(SparseEntry::NonZero(&3)));
assert_eq!(row.get_entry(1), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(4), None);
}
{
let row = row_iter.next().unwrap();
assert_eq!(row.ncols(), 4);
assert_eq!(row.nnz(), 2);
assert_eq!(row.col_indices(), &[1, 3]);
assert_eq!(row.values(), &[4, 5]);
assert_eq!(row.get_entry(0), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(1), Some(SparseEntry::NonZero(&4)));
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(3), Some(SparseEntry::NonZero(&5)));
assert_eq!(row.get_entry(4), None);
}
assert!(row_iter.next().is_none());
}
// Mutable iterator
{
let mut csr = csr;
let mut row_iter = csr.row_iter_mut();
{
let mut row = row_iter.next().unwrap();
assert_eq!(row.ncols(), 4);
assert_eq!(row.nnz(), 2);
assert_eq!(row.col_indices(), &[1, 2]);
assert_eq!(row.values(), &[1, 2]);
assert_eq!(row.get_entry(0), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(1), Some(SparseEntry::NonZero(&1)));
assert_eq!(row.get_entry(2), Some(SparseEntry::NonZero(&2)));
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(4), None);
assert_eq!(row.values_mut(), &mut [1, 2]);
assert_eq!(
row.cols_and_values_mut(),
([1, 2].as_ref(), [1, 2].as_mut())
);
assert_eq!(row.get_entry_mut(0), Some(SparseEntryMut::Zero));
assert_eq!(row.get_entry_mut(1), Some(SparseEntryMut::NonZero(&mut 1)));
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
assert_eq!(row.get_entry_mut(4), None);
}
{
let mut row = row_iter.next().unwrap();
assert_eq!(row.ncols(), 4);
assert_eq!(row.nnz(), 1);
assert_eq!(row.col_indices(), &[0]);
assert_eq!(row.values(), &[3]);
assert_eq!(row.get_entry(0), Some(SparseEntry::NonZero(&3)));
assert_eq!(row.get_entry(1), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(4), None);
assert_eq!(row.values_mut(), &mut [3]);
assert_eq!(row.cols_and_values_mut(), ([0].as_ref(), [3].as_mut()));
assert_eq!(row.get_entry_mut(0), Some(SparseEntryMut::NonZero(&mut 3)));
assert_eq!(row.get_entry_mut(1), Some(SparseEntryMut::Zero));
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
assert_eq!(row.get_entry_mut(4), None);
}
{
let mut row = row_iter.next().unwrap();
assert_eq!(row.ncols(), 4);
assert_eq!(row.nnz(), 2);
assert_eq!(row.col_indices(), &[1, 3]);
assert_eq!(row.values(), &[4, 5]);
assert_eq!(row.get_entry(0), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(1), Some(SparseEntry::NonZero(&4)));
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
assert_eq!(row.get_entry(3), Some(SparseEntry::NonZero(&5)));
assert_eq!(row.get_entry(4), None);
assert_eq!(row.values_mut(), &mut [4, 5]);
assert_eq!(
row.cols_and_values_mut(),
([1, 3].as_ref(), [4, 5].as_mut())
);
assert_eq!(row.get_entry_mut(0), Some(SparseEntryMut::Zero));
assert_eq!(row.get_entry_mut(1), Some(SparseEntryMut::NonZero(&mut 4)));
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
assert_eq!(row.get_entry_mut(4), None);
}
assert!(row_iter.next().is_none());
}
}
proptest! {
#[test]
fn csr_double_transpose_is_identity(csr in csr_strategy()) {
prop_assert_eq!(csr.transpose().transpose(), csr);
}
#[test]
fn csr_transpose_agrees_with_dense(csr in csr_strategy()) {
let dense_transpose = DMatrix::from(&csr).transpose();
let csr_transpose = csr.transpose();
prop_assert_eq!(dense_transpose, DMatrix::from(&csr_transpose));
prop_assert_eq!(csr.nnz(), csr_transpose.nnz());
}
#[test]
fn csr_filter(
(csr, triplet_subset)
in csr_strategy()
.prop_flat_map(|matrix| {
let triplets: Vec<_> = matrix.triplet_iter().cloned_values().collect();
let subset = subsequence(triplets, 0 ..= matrix.nnz())
.prop_map(|triplet_subset| {
let set: HashSet<_> = triplet_subset.into_iter().collect();
set
});
(Just(matrix), subset)
}))
{
// We generate a CsrMatrix and a HashSet corresponding to a subset of the (i, j, v)
// values in the matrix, which we use for filtering the matrix entries.
// The resulting triplets in the filtered matrix must then be exactly equal to
// the subset.
let filtered = csr.filter(|i, j, v| triplet_subset.contains(&(i, j, *v)));
let filtered_triplets: HashSet<_> = filtered
.triplet_iter()
.cloned_values()
.collect();
prop_assert_eq!(filtered_triplets, triplet_subset);
}
#[test]
fn csr_lower_triangle_agrees_with_dense(csr in csr_strategy()) {
let csr_lower_triangle = csr.lower_triangle();
prop_assert_eq!(DMatrix::from(&csr_lower_triangle), DMatrix::from(&csr).lower_triangle());
prop_assert!(csr_lower_triangle.nnz() <= csr.nnz());
}
#[test]
fn csr_upper_triangle_agrees_with_dense(csr in csr_strategy()) {
let csr_upper_triangle = csr.upper_triangle();
prop_assert_eq!(DMatrix::from(&csr_upper_triangle), DMatrix::from(&csr).upper_triangle());
prop_assert!(csr_upper_triangle.nnz() <= csr.nnz());
}
#[test]
fn csr_diagonal_as_csr(csr in csr_strategy()) {
let d = csr.diagonal_as_csr();
let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect();
let csr_diagonal_entries: HashSet<_> = csr
.triplet_iter()
.cloned_values()
.filter(|&(i, j, _)| i == j)
.collect();
prop_assert_eq!(d_entries, csr_diagonal_entries);
}
#[test]
fn csr_identity(n in 0 ..= 6usize) {
let csr = CsrMatrix::<i32>::identity(n);
prop_assert_eq!(csr.nnz(), n);
prop_assert_eq!(DMatrix::from(&csr), DMatrix::identity(n, n));
}
}

View File

@ -0,0 +1,8 @@
mod cholesky;
mod convert_serial;
mod coo;
mod csc;
mod csr;
mod ops;
mod pattern;
mod proptest;

View File

@ -0,0 +1,14 @@
# Seeds for failure cases proptest has generated in the past. It is
# automatically read and these particular cases re-run before any
# novel cases are generated.
#
# It is recommended to check this file in to source control so that
# everyone who runs the test benefits from these saved cases.
cc 6748ea4ac9523fcc4dd8327b27c6818f8df10eb2042774f59a6e3fa3205dbcbd # shrinks to (beta, alpha, (c, a, b)) = (0, -1, (Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 1, 5, -4, 2], nrows: Dynamic { value: 3 }, ncols: Dynamic { value: 3 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 2, 2, 2], minor_indices: [0, 1], minor_dim: 5 }, values: [-5, -2] }, Matrix { data: VecStorage { data: [4, -2, -3, -3, -5, 3, 5, 1, -4, -4, 3, 5, 5, 5, -3], nrows: Dynamic { value: 5 }, ncols: Dynamic { value: 3 } } }))
cc dcf67ab7b8febf109cfa58ee0f082b9f7c23d6ad0df2e28dc99984deeb6b113a # shrinks to (beta, alpha, (c, a, b)) = (0, 0, (Matrix { data: VecStorage { data: [0, -1], nrows: Dynamic { value: 1 }, ncols: Dynamic { value: 2 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0], minor_indices: [], minor_dim: 4 }, values: [] }, Matrix { data: VecStorage { data: [3, 1, 1, 0, 0, 3, -5, -3], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 2 } } }))
cc dbaef9886eaad28be7cd48326b857f039d695bc0b19e9ada3304e812e984d2c3 # shrinks to (beta, alpha, (c, a, b)) = (0, -1, (Matrix { data: VecStorage { data: [1], nrows: Dynamic { value: 1 }, ncols: Dynamic { value: 1 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0], minor_indices: [], minor_dim: 0 }, values: [] }, Matrix { data: VecStorage { data: [], nrows: Dynamic { value: 0 }, ncols: Dynamic { value: 1 } } }))
cc 99e312beb498ffa79194f41501ea312dce1911878eba131282904ac97205aaa9 # shrinks to SpmmCsrDenseArgs { c, beta, alpha, trans_a, a, trans_b, b } = SpmmCsrDenseArgs { c: Matrix { data: VecStorage { data: [-1, 4, -1, -4, 2, 1, 4, -2, 1, 3, -2, 5], nrows: Dynamic { value: 2 }, ncols: Dynamic { value: 6 } } }, beta: 0, alpha: 0, trans_a: Transpose, a: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 1, 1, 1, 1, 1, 1], minor_indices: [0], minor_dim: 2 }, values: [0] }, trans_b: Transpose, b: Matrix { data: VecStorage { data: [-1, 1, 0, -5, 4, -5, 2, 2, 4, -4, -3, -1, 1, -1, 0, 1, -3, 4, -5, 0, 1, -5, 0, 1, 1, -3, 5, 3, 5, -3, -5, 3, -1, -4, -4, -3], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 6 } } } }
cc bf74259df2db6eda24eb42098e57ea1c604bb67d6d0023fa308c321027b53a43 # shrinks to (alpha, beta, c, a, b, trans_a, trans_b) = (0, 0, Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 5 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 3, 6, 9, 12], minor_indices: [0, 1, 3, 1, 2, 3, 0, 1, 2, 1, 2, 3], minor_dim: 4 }, values: [-3, 3, -3, 1, -3, 0, 2, 1, 3, 0, -4, -1] }, Matrix { data: VecStorage { data: [3, 1, 4, -5, 5, -2, -5, -1, 1, -1, 3, -3, -2, 4, 2, -1, -1, 3, -5, 5], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 5 } } }, NoTranspose, NoTranspose)
cc cbd6dac45a2f610e10cf4c15d4614cdbf7dfedbfcd733e4cc65c2e79829d14b3 # shrinks to SpmmCsrArgs { c, beta, alpha, trans_a, a, trans_b, b } = SpmmCsrArgs { c: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0, 1, 1, 1, 1], minor_indices: [0], minor_dim: 1 }, values: [0] }, beta: 0, alpha: 1, trans_a: Transpose(true), a: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0, 0, 1, 1, 1], minor_indices: [1], minor_dim: 5 }, values: [-1] }, trans_b: Transpose(true), b: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 2], minor_indices: [2, 4], minor_dim: 5 }, values: [-1, 0] } }
cc 8af78e2e41087743c8696c4d5563d59464f284662ccf85efc81ac56747d528bb # shrinks to (a, b) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 6, 12, 18, 24, 30, 33], minor_indices: [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 1, 2, 5], minor_dim: 6 }, values: [0.4566433975117654, -0.5109683327713039, 0.0, -3.276901622678194, 0.0, -2.2065487385437095, 0.0, -0.42643054427847016, -2.9232369281581234, 0.0, 1.2913925579441763, 0.0, -1.4073766622090917, -4.795473113569459, 4.681765156869446, -0.821162215887913, 3.0315816068414794, -3.3986924718213407, -3.498903007282241, -3.1488953408335236, 3.458104636152161, -4.774694888508124, 2.603884664757498, 0.0, 0.0, -3.2650988857765535, 4.26699442646613, 0.0, -0.012223422086023561, 3.6899095325779285, -1.4264458042247958, 0.0, 3.4849193883471266] } }, Matrix { data: VecStorage { data: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.9513896933988457, -4.426942420881461, 0.0, 0.0, 0.0, -0.28264084049240257], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 2 } } })
cc a4effd988fe352146fca365875e108ecf4f7d41f6ad54683e923ca6ce712e5d0 # shrinks to (a, b) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 5, 11, 17, 22, 27, 31], minor_indices: [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 3, 4, 5, 1, 2, 3, 4, 5, 0, 1, 3, 5], minor_dim: 6 }, values: [-2.24935510943371, -2.2288203680206227, 0.0, -1.029740125494273, 0.0, 0.0, 0.22632926934348507, -0.9123245943877407, 0.0, 3.8564332876991827, 0.0, 0.0, 0.0, -0.8235065737081717, 1.9337984046721566, 0.11003468246027737, -3.422112890579867, -3.7824068893569196, 0.0, -0.021700572247226546, -4.914783069982362, 0.6227245544506541, 0.0, 0.0, -4.411368879922364, -0.00013623178651567258, -2.613658177661417, -2.2783292441548637, 0.0, 1.351859435890189, -0.021345159183605134] } }, Matrix { data: VecStorage { data: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -4.519417607973404, 0.0, 0.0, 0.0, -0.21238483334481817], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 3 } } })

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,154 @@
use nalgebra_sparse::pattern::{SparsityPattern, SparsityPatternFormatError};
#[test]
fn sparsity_pattern_valid_data() {
// Construct pattern from valid data and check that selected methods return results
// that agree with expectations.
{
// A pattern with zero explicitly stored entries
let pattern =
SparsityPattern::try_from_offsets_and_indices(3, 2, vec![0, 0, 0, 0], Vec::new())
.unwrap();
assert_eq!(pattern.major_dim(), 3);
assert_eq!(pattern.minor_dim(), 2);
assert_eq!(pattern.nnz(), 0);
assert_eq!(pattern.major_offsets(), &[0, 0, 0, 0]);
assert_eq!(pattern.minor_indices(), &[]);
assert_eq!(pattern.lane(0), &[]);
assert_eq!(pattern.lane(1), &[]);
assert_eq!(pattern.lane(2), &[]);
assert!(pattern.entries().next().is_none());
assert_eq!(pattern, SparsityPattern::zeros(3, 2));
let (offsets, indices) = pattern.disassemble();
assert_eq!(offsets, vec![0, 0, 0, 0]);
assert_eq!(indices, vec![]);
}
{
// Arbitrary pattern
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let pattern =
SparsityPattern::try_from_offsets_and_indices(3, 6, offsets.clone(), indices.clone())
.unwrap();
assert_eq!(pattern.major_dim(), 3);
assert_eq!(pattern.minor_dim(), 6);
assert_eq!(pattern.major_offsets(), offsets.as_slice());
assert_eq!(pattern.minor_indices(), indices.as_slice());
assert_eq!(pattern.nnz(), 5);
assert_eq!(pattern.lane(0), &[0, 5]);
assert_eq!(pattern.lane(1), &[]);
assert_eq!(pattern.lane(2), &[1, 2, 3]);
assert_eq!(
pattern.entries().collect::<Vec<_>>(),
vec![(0, 0), (0, 5), (2, 1), (2, 2), (2, 3)]
);
let (offsets2, indices2) = pattern.disassemble();
assert_eq!(offsets2, offsets);
assert_eq!(indices2, indices);
}
}
#[test]
fn sparsity_pattern_try_from_invalid_data() {
{
// Empty offset array (invalid length)
let pattern = SparsityPattern::try_from_offsets_and_indices(0, 0, Vec::new(), Vec::new());
assert_eq!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
);
}
{
// Offset array invalid length for arbitrary data
let offsets = vec![0, 3, 5];
let indices = vec![0, 1, 2, 3, 5];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert!(matches!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
));
}
{
// Invalid first entry in offsets array
let offsets = vec![1, 2, 2, 5];
let indices = vec![0, 5, 1, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert!(matches!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
));
}
{
// Invalid last entry in offsets array
let offsets = vec![0, 2, 2, 4];
let indices = vec![0, 5, 1, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert!(matches!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
));
}
{
// Invalid length of offsets array
let offsets = vec![0, 2, 2];
let indices = vec![0, 5, 1, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert!(matches!(
pattern,
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
));
}
{
// Nonmonotonic offsets
let offsets = vec![0, 3, 2, 5];
let indices = vec![0, 1, 2, 3, 4];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert_eq!(
pattern,
Err(SparsityPatternFormatError::NonmonotonicOffsets)
);
}
{
// Nonmonotonic minor indices
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 2, 3, 1, 4];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert_eq!(
pattern,
Err(SparsityPatternFormatError::NonmonotonicMinorIndices)
);
}
{
// Minor index out of bounds
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 6, 1, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert_eq!(
pattern,
Err(SparsityPatternFormatError::MinorIndexOutOfBounds)
);
}
{
// Duplicate entry
let offsets = vec![0, 2, 2, 5];
let indices = vec![0, 5, 2, 2, 3];
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
assert_eq!(pattern, Err(SparsityPatternFormatError::DuplicateEntry));
}
}

View File

@ -0,0 +1,247 @@
#[test]
#[ignore]
fn coo_no_duplicates_generates_admissible_matrices() {
//TODO
}
#[cfg(feature = "slow-tests")]
mod slow {
use nalgebra::DMatrix;
use nalgebra_sparse::proptest::{
coo_no_duplicates, coo_with_duplicates, csc, csr, sparsity_pattern,
};
use itertools::Itertools;
use proptest::strategy::ValueTree;
use proptest::test_runner::TestRunner;
use proptest::prelude::*;
use nalgebra_sparse::csr::CsrMatrix;
use std::collections::HashSet;
use std::iter::repeat;
use std::ops::RangeInclusive;
fn generate_all_possible_matrices(
value_range: RangeInclusive<i32>,
rows_range: RangeInclusive<usize>,
cols_range: RangeInclusive<usize>,
) -> HashSet<DMatrix<i32>> {
// Enumerate all possible combinations
let mut all_combinations = HashSet::new();
for nrows in rows_range {
for ncols in cols_range.clone() {
// For the given number of rows and columns
let n_values = nrows * ncols;
if n_values == 0 {
// If we have zero rows or columns, the set of matrices with the given
// rows and columns is a single element: an empty matrix
all_combinations.insert(DMatrix::from_row_slice(nrows, ncols, &[]));
} else {
// Otherwise, we need to sample all possible matrices.
// To do this, we generate the values as the (multi) Cartesian product
// of the value sets. For example, for a 2x2 matrices, we consider
// all possible 4-element arrays that the matrices can take by
// considering all elements in the cartesian product
// V x V x V x V
// where V is the set of eligible values, e.g. V := -1 ..= 1
let values_iter = repeat(value_range.clone())
.take(n_values)
.multi_cartesian_product();
for matrix_values in values_iter {
all_combinations.insert(DMatrix::from_row_slice(
nrows,
ncols,
&matrix_values,
));
}
}
}
}
all_combinations
}
#[cfg(feature = "slow-tests")]
#[test]
fn coo_no_duplicates_samples_all_admissible_outputs() {
// Note: This test basically mirrors a similar test for `matrix` in the `nalgebra` repo.
// Test that the proptest generation covers all possible outputs for a small space of inputs
// given enough samples.
// We use a deterministic test runner to make the test "stable".
let mut runner = TestRunner::deterministic();
// This number needs to be high enough so that we with high probability sample
// all possible cases
let num_generated_matrices = 500000;
let values = -1..=1;
let rows = 0..=2;
let cols = 0..=3;
let max_nnz = rows.end() * cols.end();
let strategy = coo_no_duplicates(values.clone(), rows.clone(), cols.clone(), max_nnz);
// Enumerate all possible combinations
let all_combinations = generate_all_possible_matrices(values, rows, cols);
let visited_combinations =
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
assert_eq!(visited_combinations.len(), all_combinations.len());
assert_eq!(
visited_combinations, all_combinations,
"Did not sample all possible values."
);
}
#[cfg(feature = "slow-tests")]
#[test]
fn coo_with_duplicates_samples_all_admissible_outputs() {
// This is almost the same as the test for coo_no_duplicates, except that we need
// a different "success" criterion, since coo_with_duplicates is able to generate
// matrices with values outside of the value constraints. See below for details.
// We use a deterministic test runner to make the test "stable".
let mut runner = TestRunner::deterministic();
// This number needs to be high enough so that we with high probability sample
// all possible cases
let num_generated_matrices = 500000;
let values = -1..=1;
let rows = 0..=2;
let cols = 0..=3;
let max_nnz = rows.end() * cols.end();
let strategy = coo_with_duplicates(values.clone(), rows.clone(), cols.clone(), max_nnz, 2);
// Enumerate all possible combinations that fit the constraints
// (note: this is only a subset of the matrices that can be generated by
// `coo_with_duplicates`)
let all_combinations = generate_all_possible_matrices(values, rows, cols);
let visited_combinations =
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
// Here we cannot verify that the set of visited combinations is *equal* to
// all possible outcomes with the given constraints, however the
// strategy should be able to generate all matrices that fit the constraints.
// In other words, we need to determine that set of all admissible matrices
// is contained in the set of visited matrices
assert!(all_combinations.is_subset(&visited_combinations));
}
#[cfg(feature = "slow-tests")]
#[test]
fn csr_samples_all_admissible_outputs() {
// We use a deterministic test runner to make the test "stable".
let mut runner = TestRunner::deterministic();
// This number needs to be high enough so that we with high probability sample
// all possible cases
let num_generated_matrices = 500000;
let values = -1..=1;
let rows = 0..=2;
let cols = 0..=3;
let max_nnz = rows.end() * cols.end();
let strategy = csr(values.clone(), rows.clone(), cols.clone(), max_nnz);
let all_combinations = generate_all_possible_matrices(values, rows, cols);
let visited_combinations =
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
assert_eq!(visited_combinations.len(), all_combinations.len());
assert_eq!(
visited_combinations, all_combinations,
"Did not sample all possible values."
);
}
#[cfg(feature = "slow-tests")]
#[test]
fn csc_samples_all_admissible_outputs() {
// We use a deterministic test runner to make the test "stable".
let mut runner = TestRunner::deterministic();
// This number needs to be high enough so that we with high probability sample
// all possible cases
let num_generated_matrices = 500000;
let values = -1..=1;
let rows = 0..=2;
let cols = 0..=3;
let max_nnz = rows.end() * cols.end();
let strategy = csc(values.clone(), rows.clone(), cols.clone(), max_nnz);
let all_combinations = generate_all_possible_matrices(values, rows, cols);
let visited_combinations =
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
assert_eq!(visited_combinations.len(), all_combinations.len());
assert_eq!(
visited_combinations, all_combinations,
"Did not sample all possible values."
);
}
#[cfg(feature = "slow-tests")]
#[test]
fn sparsity_pattern_samples_all_admissible_outputs() {
let mut runner = TestRunner::deterministic();
let num_generated_patterns = 50000;
let major_dims = 0..=2;
let minor_dims = 0..=3;
let max_nnz = major_dims.end() * minor_dims.end();
let strategy = sparsity_pattern(major_dims.clone(), minor_dims.clone(), max_nnz);
let visited_patterns: HashSet<_> = sample_strategy(strategy, &mut runner)
.take(num_generated_patterns)
.map(|pattern| {
// We represent patterns as dense matrices with 1 if an entry is occupied,
// 0 otherwise
let values = vec![1; pattern.nnz()];
CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap()
})
.map(|csr| DMatrix::from(&csr))
.collect();
let all_possible_patterns = generate_all_possible_matrices(0..=1, major_dims, minor_dims);
assert_eq!(visited_patterns.len(), all_possible_patterns.len());
assert_eq!(visited_patterns, all_possible_patterns);
}
fn sample_matrix_output_space<S>(
strategy: S,
runner: &mut TestRunner,
num_samples: usize,
) -> HashSet<DMatrix<i32>>
where
S: Strategy,
DMatrix<i32>: for<'b> From<&'b S::Value>,
{
sample_strategy(strategy, runner)
.take(num_samples)
.map(|matrix| DMatrix::from(&matrix))
.collect()
}
fn sample_strategy<'a, S: 'a + Strategy>(
strategy: S,
runner: &'a mut TestRunner,
) -> impl 'a + Iterator<Item = S::Value> {
repeat(()).map(move |_| {
let tree = strategy
.new_tree(runner)
.expect("Tree generation should not fail");
let value = tree.current();
value
})
}
}

View File

@ -1,6 +1,7 @@
//! Abstract definition of a matrix data storage allocator. //! Abstract definition of a matrix data storage allocator.
use std::any::Any; use std::any::Any;
use std::mem;
use crate::base::constraint::{SameNumberOfColumns, SameNumberOfRows, ShapeConstraint}; use crate::base::constraint::{SameNumberOfColumns, SameNumberOfRows, ShapeConstraint};
use crate::base::dimension::{Dim, U1}; use crate::base::dimension::{Dim, U1};
@ -21,7 +22,7 @@ pub trait Allocator<N: Scalar, R: Dim, C: Dim = U1>: Any + Sized {
type Buffer: ContiguousStorageMut<N, R, C> + Clone; type Buffer: ContiguousStorageMut<N, R, C> + Clone;
/// Allocates a buffer with the given number of rows and columns without initializing its content. /// Allocates a buffer with the given number of rows and columns without initializing its content.
unsafe fn allocate_uninitialized(nrows: R, ncols: C) -> Self::Buffer; unsafe fn allocate_uninitialized(nrows: R, ncols: C) -> mem::MaybeUninit<Self::Buffer>;
/// Allocates a buffer initialized with the content of the given iterator. /// Allocates a buffer initialized with the content of the given iterator.
fn allocate_from_iterator<I: IntoIterator<Item = N>>( fn allocate_from_iterator<I: IntoIterator<Item = N>>(

View File

@ -394,6 +394,26 @@ where
} }
} }
#[cfg(feature = "bytemuck")]
unsafe impl<N: Scalar + bytemuck::Zeroable, R: DimName, C: DimName> bytemuck::Zeroable
for ArrayStorage<N, R, C>
where
R::Value: Mul<C::Value>,
Prod<R::Value, C::Value>: ArrayLength<N>,
Self: Copy,
{
}
#[cfg(feature = "bytemuck")]
unsafe impl<N: Scalar + bytemuck::Pod, R: DimName, C: DimName> bytemuck::Pod
for ArrayStorage<N, R, C>
where
R::Value: Mul<C::Value>,
Prod<R::Value, C::Value>: ArrayLength<N>,
Self: Copy,
{
}
#[cfg(feature = "abomonation-serialize")] #[cfg(feature = "abomonation-serialize")]
impl<N, R, C> Abomonation for ArrayStorage<N, R, C> impl<N, R, C> Abomonation for ArrayStorage<N, R, C>
where where

View File

@ -1328,7 +1328,8 @@ where
ShapeConstraint: DimEq<D1, D1> + DimEq<D1, R3> + DimEq<C3, D4>, ShapeConstraint: DimEq<D1, D1> + DimEq<D1, R3> + DimEq<C3, D4>,
DefaultAllocator: Allocator<N, D1>, DefaultAllocator: Allocator<N, D1>,
{ {
let mut work = unsafe { Vector::new_uninitialized_generic(self.data.shape().0, U1) }; let mut work =
unsafe { crate::unimplemented_or_uninitialized_generic!(self.data.shape().0, U1) };
self.quadform_tr_with_workspace(&mut work, alpha, lhs, mid, beta) self.quadform_tr_with_workspace(&mut work, alpha, lhs, mid, beta)
} }
@ -1421,7 +1422,8 @@ where
ShapeConstraint: DimEq<D2, R3> + DimEq<D1, C3> + AreMultipliable<C3, R3, D2, U1>, ShapeConstraint: DimEq<D2, R3> + DimEq<D1, C3> + AreMultipliable<C3, R3, D2, U1>,
DefaultAllocator: Allocator<N, D2>, DefaultAllocator: Allocator<N, D2>,
{ {
let mut work = unsafe { Vector::new_uninitialized_generic(mid.data.shape().0, U1) }; let mut work =
unsafe { crate::unimplemented_or_uninitialized_generic!(mid.data.shape().0, U1) };
self.quadform_with_workspace(&mut work, alpha, mid, rhs, beta) self.quadform_with_workspace(&mut work, alpha, mid, rhs, beta)
} }
} }

View File

@ -14,6 +14,7 @@ use rand::Rng;
#[cfg(feature = "std")] #[cfg(feature = "std")]
use rand_distr::StandardNormal; use rand_distr::StandardNormal;
use std::iter; use std::iter;
use std::mem;
use typenum::{self, Cmp, Greater}; use typenum::{self, Cmp, Greater};
#[cfg(feature = "std")] #[cfg(feature = "std")]
@ -25,6 +26,23 @@ use crate::base::dimension::{Dim, DimName, Dynamic, U1, U2, U3, U4, U5, U6};
use crate::base::storage::Storage; use crate::base::storage::Storage;
use crate::base::{DefaultAllocator, Matrix, MatrixMN, MatrixN, Scalar, Unit, Vector, VectorN}; use crate::base::{DefaultAllocator, Matrix, MatrixMN, MatrixN, Scalar, Unit, Vector, VectorN};
/// When "no_unsound_assume_init" is enabled, expands to `unimplemented!()` instead of `new_uninitialized_generic().assume_init()`.
/// Intended as a placeholder, each callsite should be refactored to use uninitialized memory soundly
#[macro_export]
macro_rules! unimplemented_or_uninitialized_generic {
($nrows:expr, $ncols:expr) => {{
#[cfg(feature="no_unsound_assume_init")] {
// Some of the call sites need the number of rows and columns from this to infer a type, so
// uninitialized memory is used to infer the type, as `N: Zero` isn't available at all callsites.
// This may technically still be UB even though the assume_init is dead code, but all callsites should be fixed before #556 is closed.
let typeinference_helper = crate::base::Matrix::new_uninitialized_generic($nrows, $ncols);
unimplemented!();
typeinference_helper.assume_init()
}
#[cfg(not(feature="no_unsound_assume_init"))] { crate::base::Matrix::new_uninitialized_generic($nrows, $ncols).assume_init() }
}}
}
/// # Generic constructors /// # Generic constructors
/// This set of matrix and vector construction functions are all generic /// This set of matrix and vector construction functions are all generic
/// with-regard to the matrix dimensions. They all expect to be given /// with-regard to the matrix dimensions. They all expect to be given
@ -38,8 +56,8 @@ where
/// Creates a new uninitialized matrix. If the matrix has a compile-time dimension, this panics /// Creates a new uninitialized matrix. If the matrix has a compile-time dimension, this panics
/// if `nrows != R::to_usize()` or `ncols != C::to_usize()`. /// if `nrows != R::to_usize()` or `ncols != C::to_usize()`.
#[inline] #[inline]
pub unsafe fn new_uninitialized_generic(nrows: R, ncols: C) -> Self { pub unsafe fn new_uninitialized_generic(nrows: R, ncols: C) -> mem::MaybeUninit<Self> {
Self::from_data(DefaultAllocator::allocate_uninitialized(nrows, ncols)) Self::from_uninitialized_data(DefaultAllocator::allocate_uninitialized(nrows, ncols))
} }
/// Creates a matrix with all its elements set to `elem`. /// Creates a matrix with all its elements set to `elem`.
@ -88,7 +106,7 @@ where
"Matrix init. error: the slice did not contain the right number of elements." "Matrix init. error: the slice did not contain the right number of elements."
); );
let mut res = unsafe { Self::new_uninitialized_generic(nrows, ncols) }; let mut res = unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
let mut iter = slice.iter(); let mut iter = slice.iter();
for i in 0..nrows.value() { for i in 0..nrows.value() {
@ -114,7 +132,7 @@ where
where where
F: FnMut(usize, usize) -> N, F: FnMut(usize, usize) -> N,
{ {
let mut res = unsafe { Self::new_uninitialized_generic(nrows, ncols) }; let mut res: Self = unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
for j in 0..ncols.value() { for j in 0..ncols.value() {
for i in 0..nrows.value() { for i in 0..nrows.value() {
@ -356,7 +374,7 @@ macro_rules! impl_constructors(
($($Dims: ty),*; $(=> $DimIdent: ident: $DimBound: ident),*; $($gargs: expr),*; $($args: ident),*) => { ($($Dims: ty),*; $(=> $DimIdent: ident: $DimBound: ident),*; $($gargs: expr),*; $($args: ident),*) => {
/// Creates a new uninitialized matrix or vector. /// Creates a new uninitialized matrix or vector.
#[inline] #[inline]
pub unsafe fn new_uninitialized($($args: usize),*) -> Self { pub unsafe fn new_uninitialized($($args: usize),*) -> mem::MaybeUninit<Self> {
Self::new_uninitialized_generic($($gargs),*) Self::new_uninitialized_generic($($gargs),*)
} }
@ -806,8 +824,8 @@ where
{ {
#[inline] #[inline]
fn sample<'a, G: Rng + ?Sized>(&self, rng: &'a mut G) -> MatrixMN<N, R, C> { fn sample<'a, G: Rng + ?Sized>(&self, rng: &'a mut G) -> MatrixMN<N, R, C> {
let nrows = R::try_to_usize().unwrap_or_else(|| rng.gen_range(0, 10)); let nrows = R::try_to_usize().unwrap_or_else(|| rng.gen_range(0..10));
let ncols = C::try_to_usize().unwrap_or_else(|| rng.gen_range(0, 10)); let ncols = C::try_to_usize().unwrap_or_else(|| rng.gen_range(0..10));
MatrixMN::from_fn_generic(R::from_usize(nrows), C::from_usize(ncols), |_, _| rng.gen()) MatrixMN::from_fn_generic(R::from_usize(nrows), C::from_usize(ncols), |_, _| rng.gen())
} }
@ -823,9 +841,9 @@ where
Owned<N, R, C>: Clone + Send, Owned<N, R, C>: Clone + Send,
{ {
#[inline] #[inline]
fn arbitrary<G: Gen>(g: &mut G) -> Self { fn arbitrary(g: &mut Gen) -> Self {
let nrows = R::try_to_usize().unwrap_or(g.gen_range(0, 10)); let nrows = R::try_to_usize().unwrap_or(usize::arbitrary(g) % 10);
let ncols = C::try_to_usize().unwrap_or(g.gen_range(0, 10)); let ncols = C::try_to_usize().unwrap_or(usize::arbitrary(g) % 10);
Self::from_fn_generic(R::from_usize(nrows), C::from_usize(ncols), |_, _| { Self::from_fn_generic(R::from_usize(nrows), C::from_usize(ncols), |_, _| {
N::arbitrary(g) N::arbitrary(g)
@ -865,7 +883,10 @@ macro_rules! componentwise_constructors_impl(
#[inline] #[inline]
pub fn new($($args: N),*) -> Self { pub fn new($($args: N),*) -> Self {
unsafe { unsafe {
let mut res = Self::new_uninitialized(); #[cfg(feature="no_unsound_assume_init")]
let mut res: Self = unimplemented!();
#[cfg(not(feature="no_unsound_assume_init"))]
let mut res = Self::new_uninitialized().assume_init();
$( *res.get_unchecked_mut(($irow, $icol)) = $args; )* $( *res.get_unchecked_mut(($irow, $icol)) = $args; )*
res res

View File

@ -50,7 +50,8 @@ where
let nrows2 = R2::from_usize(nrows); let nrows2 = R2::from_usize(nrows);
let ncols2 = C2::from_usize(ncols); let ncols2 = C2::from_usize(ncols);
let mut res = unsafe { MatrixMN::<N2, R2, C2>::new_uninitialized_generic(nrows2, ncols2) }; let mut res: MatrixMN<N2, R2, C2> =
unsafe { crate::unimplemented_or_uninitialized_generic!(nrows2, ncols2) };
for i in 0..nrows { for i in 0..nrows {
for j in 0..ncols { for j in 0..ncols {
unsafe { unsafe {
@ -73,7 +74,7 @@ where
let nrows = R1::from_usize(nrows2); let nrows = R1::from_usize(nrows2);
let ncols = C1::from_usize(ncols2); let ncols = C1::from_usize(ncols2);
let mut res = unsafe { Self::new_uninitialized_generic(nrows, ncols) }; let mut res: Self = unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
for i in 0..nrows2 { for i in 0..nrows2 {
for j in 0..ncols2 { for j in 0..ncols2 {
unsafe { unsafe {
@ -117,9 +118,9 @@ macro_rules! impl_from_into_asref_1D(
fn from(arr: [N; $SZ]) -> Self { fn from(arr: [N; $SZ]) -> Self {
unsafe { unsafe {
let mut res = Self::new_uninitialized(); let mut res = Self::new_uninitialized();
ptr::copy_nonoverlapping(&arr[0], res.data.ptr_mut(), $SZ); ptr::copy_nonoverlapping(&arr[0], (*res.as_mut_ptr()).data.ptr_mut(), $SZ);
res res.assume_init()
} }
} }
} }
@ -184,9 +185,9 @@ macro_rules! impl_from_into_asref_2D(
fn from(arr: [[N; $SZRows]; $SZCols]) -> Self { fn from(arr: [[N; $SZRows]; $SZCols]) -> Self {
unsafe { unsafe {
let mut res = Self::new_uninitialized(); let mut res = Self::new_uninitialized();
ptr::copy_nonoverlapping(&arr[0][0], res.data.ptr_mut(), $SZRows * $SZCols); ptr::copy_nonoverlapping(&arr[0][0], (*res.as_mut_ptr()).data.ptr_mut(), $SZRows * $SZCols);
res res.assume_init()
} }
} }
} }
@ -244,9 +245,9 @@ macro_rules! impl_from_into_mint_1D(
fn from(v: mint::$VT<N>) -> Self { fn from(v: mint::$VT<N>) -> Self {
unsafe { unsafe {
let mut res = Self::new_uninitialized(); let mut res = Self::new_uninitialized();
ptr::copy_nonoverlapping(&v.x, res.data.ptr_mut(), $SZ); ptr::copy_nonoverlapping(&v.x, (*res.as_mut_ptr()).data.ptr_mut(), $SZ);
res res.assume_init()
} }
} }
} }
@ -306,13 +307,13 @@ macro_rules! impl_from_into_mint_2D(
fn from(m: mint::$MV<N>) -> Self { fn from(m: mint::$MV<N>) -> Self {
unsafe { unsafe {
let mut res = Self::new_uninitialized(); let mut res = Self::new_uninitialized();
let mut ptr = res.data.ptr_mut(); let mut ptr = (*res.as_mut_ptr()).data.ptr_mut();
$( $(
ptr::copy_nonoverlapping(&m.$component.x, ptr, $SZRows); ptr::copy_nonoverlapping(&m.$component.x, ptr, $SZRows);
ptr = ptr.offset($SZRows); ptr = ptr.offset($SZRows);
)* )*
let _ = ptr; let _ = ptr;
res res.assume_init()
} }
} }
} }

View File

@ -45,9 +45,8 @@ where
type Buffer = ArrayStorage<N, R, C>; type Buffer = ArrayStorage<N, R, C>;
#[inline] #[inline]
unsafe fn allocate_uninitialized(_: R, _: C) -> Self::Buffer { unsafe fn allocate_uninitialized(_: R, _: C) -> mem::MaybeUninit<Self::Buffer> {
// TODO: Undefined behavior, see #556 mem::MaybeUninit::<Self::Buffer>::uninit()
mem::MaybeUninit::<Self::Buffer>::uninit().assume_init()
} }
#[inline] #[inline]
@ -56,7 +55,10 @@ where
ncols: C, ncols: C,
iter: I, iter: I,
) -> Self::Buffer { ) -> Self::Buffer {
let mut res = unsafe { Self::allocate_uninitialized(nrows, ncols) }; #[cfg(feature = "no_unsound_assume_init")]
let mut res: Self::Buffer = unimplemented!();
#[cfg(not(feature = "no_unsound_assume_init"))]
let mut res = unsafe { Self::allocate_uninitialized(nrows, ncols).assume_init() };
let mut count = 0; let mut count = 0;
for (res, e) in res.iter_mut().zip(iter.into_iter()) { for (res, e) in res.iter_mut().zip(iter.into_iter()) {
@ -80,13 +82,13 @@ impl<N: Scalar, C: Dim> Allocator<N, Dynamic, C> for DefaultAllocator {
type Buffer = VecStorage<N, Dynamic, C>; type Buffer = VecStorage<N, Dynamic, C>;
#[inline] #[inline]
unsafe fn allocate_uninitialized(nrows: Dynamic, ncols: C) -> Self::Buffer { unsafe fn allocate_uninitialized(nrows: Dynamic, ncols: C) -> mem::MaybeUninit<Self::Buffer> {
let mut res = Vec::new(); let mut res = Vec::new();
let length = nrows.value() * ncols.value(); let length = nrows.value() * ncols.value();
res.reserve_exact(length); res.reserve_exact(length);
res.set_len(length); res.set_len(length);
VecStorage::new(nrows, ncols, res) mem::MaybeUninit::new(VecStorage::new(nrows, ncols, res))
} }
#[inline] #[inline]
@ -110,13 +112,13 @@ impl<N: Scalar, R: DimName> Allocator<N, R, Dynamic> for DefaultAllocator {
type Buffer = VecStorage<N, R, Dynamic>; type Buffer = VecStorage<N, R, Dynamic>;
#[inline] #[inline]
unsafe fn allocate_uninitialized(nrows: R, ncols: Dynamic) -> Self::Buffer { unsafe fn allocate_uninitialized(nrows: R, ncols: Dynamic) -> mem::MaybeUninit<Self::Buffer> {
let mut res = Vec::new(); let mut res = Vec::new();
let length = nrows.value() * ncols.value(); let length = nrows.value() * ncols.value();
res.reserve_exact(length); res.reserve_exact(length);
res.set_len(length); res.set_len(length);
VecStorage::new(nrows, ncols, res) mem::MaybeUninit::new(VecStorage::new(nrows, ncols, res))
} }
#[inline] #[inline]
@ -156,7 +158,11 @@ where
cto: CTo, cto: CTo,
buf: <Self as Allocator<N, RFrom, CFrom>>::Buffer, buf: <Self as Allocator<N, RFrom, CFrom>>::Buffer,
) -> ArrayStorage<N, RTo, CTo> { ) -> ArrayStorage<N, RTo, CTo> {
let mut res = <Self as Allocator<N, RTo, CTo>>::allocate_uninitialized(rto, cto); #[cfg(feature = "no_unsound_assume_init")]
let mut res: ArrayStorage<N, RTo, CTo> = unimplemented!();
#[cfg(not(feature = "no_unsound_assume_init"))]
let mut res =
<Self as Allocator<N, RTo, CTo>>::allocate_uninitialized(rto, cto).assume_init();
let (rfrom, cfrom) = buf.shape(); let (rfrom, cfrom) = buf.shape();
@ -184,7 +190,11 @@ where
cto: CTo, cto: CTo,
buf: ArrayStorage<N, RFrom, CFrom>, buf: ArrayStorage<N, RFrom, CFrom>,
) -> VecStorage<N, Dynamic, CTo> { ) -> VecStorage<N, Dynamic, CTo> {
let mut res = <Self as Allocator<N, Dynamic, CTo>>::allocate_uninitialized(rto, cto); #[cfg(feature = "no_unsound_assume_init")]
let mut res: VecStorage<N, Dynamic, CTo> = unimplemented!();
#[cfg(not(feature = "no_unsound_assume_init"))]
let mut res =
<Self as Allocator<N, Dynamic, CTo>>::allocate_uninitialized(rto, cto).assume_init();
let (rfrom, cfrom) = buf.shape(); let (rfrom, cfrom) = buf.shape();
@ -212,7 +222,11 @@ where
cto: Dynamic, cto: Dynamic,
buf: ArrayStorage<N, RFrom, CFrom>, buf: ArrayStorage<N, RFrom, CFrom>,
) -> VecStorage<N, RTo, Dynamic> { ) -> VecStorage<N, RTo, Dynamic> {
let mut res = <Self as Allocator<N, RTo, Dynamic>>::allocate_uninitialized(rto, cto); #[cfg(feature = "no_unsound_assume_init")]
let mut res: VecStorage<N, RTo, Dynamic> = unimplemented!();
#[cfg(not(feature = "no_unsound_assume_init"))]
let mut res =
<Self as Allocator<N, RTo, Dynamic>>::allocate_uninitialized(rto, cto).assume_init();
let (rfrom, cfrom) = buf.shape(); let (rfrom, cfrom) = buf.shape();

View File

@ -54,8 +54,9 @@ impl<N: Scalar + Zero, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
{ {
let irows = irows.into_iter(); let irows = irows.into_iter();
let ncols = self.data.shape().1; let ncols = self.data.shape().1;
let mut res = let mut res = unsafe {
unsafe { MatrixMN::new_uninitialized_generic(Dynamic::new(irows.len()), ncols) }; crate::unimplemented_or_uninitialized_generic!(Dynamic::new(irows.len()), ncols)
};
// First, check that all the indices from irows are valid. // First, check that all the indices from irows are valid.
// This will allow us to use unchecked access in the inner loop. // This will allow us to use unchecked access in the inner loop.
@ -89,8 +90,9 @@ impl<N: Scalar + Zero, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
{ {
let icols = icols.into_iter(); let icols = icols.into_iter();
let nrows = self.data.shape().0; let nrows = self.data.shape().0;
let mut res = let mut res = unsafe {
unsafe { MatrixMN::new_uninitialized_generic(nrows, Dynamic::new(icols.len())) }; crate::unimplemented_or_uninitialized_generic!(nrows, Dynamic::new(icols.len()))
};
for (destination, source) in icols.enumerate() { for (destination, source) in icols.enumerate() {
res.column_mut(destination).copy_from(&self.column(*source)) res.column_mut(destination).copy_from(&self.column(*source))
@ -896,7 +898,9 @@ impl<N: Scalar> DMatrix<N> {
where where
DefaultAllocator: Reallocator<N, Dynamic, Dynamic, Dynamic, Dynamic>, DefaultAllocator: Reallocator<N, Dynamic, Dynamic, Dynamic, Dynamic>,
{ {
let placeholder = unsafe { Self::new_uninitialized(0, 0) }; let placeholder = unsafe {
crate::unimplemented_or_uninitialized_generic!(Dynamic::new(0), Dynamic::new(0))
};
let old = mem::replace(self, placeholder); let old = mem::replace(self, placeholder);
let new = old.resize(new_nrows, new_ncols, val); let new = old.resize(new_nrows, new_ncols, val);
let _ = mem::replace(self, new); let _ = mem::replace(self, new);
@ -919,8 +923,9 @@ where
where where
DefaultAllocator: Reallocator<N, Dynamic, C, Dynamic, C>, DefaultAllocator: Reallocator<N, Dynamic, C, Dynamic, C>,
{ {
let placeholder = let placeholder = unsafe {
unsafe { Self::new_uninitialized_generic(Dynamic::new(0), self.data.shape().1) }; crate::unimplemented_or_uninitialized_generic!(Dynamic::new(0), self.data.shape().1)
};
let old = mem::replace(self, placeholder); let old = mem::replace(self, placeholder);
let new = old.resize_vertically(new_nrows, val); let new = old.resize_vertically(new_nrows, val);
let _ = mem::replace(self, new); let _ = mem::replace(self, new);
@ -943,8 +948,9 @@ where
where where
DefaultAllocator: Reallocator<N, R, Dynamic, R, Dynamic>, DefaultAllocator: Reallocator<N, R, Dynamic, R, Dynamic>,
{ {
let placeholder = let placeholder = unsafe {
unsafe { Self::new_uninitialized_generic(self.data.shape().0, Dynamic::new(0)) }; crate::unimplemented_or_uninitialized_generic!(self.data.shape().0, Dynamic::new(0))
};
let old = mem::replace(self, placeholder); let old = mem::replace(self, placeholder);
let new = old.resize_horizontally(new_ncols, val); let new = old.resize_horizontally(new_ncols, val);
let _ = mem::replace(self, new); let _ = mem::replace(self, new);

View File

@ -7,7 +7,7 @@ use rand::Rng;
#[cfg(feature = "arbitrary")] #[cfg(feature = "arbitrary")]
#[doc(hidden)] #[doc(hidden)]
#[inline] #[inline]
pub fn reject<G: Gen, F: FnMut(&T) -> bool, T: Arbitrary>(g: &mut G, f: F) -> T { pub fn reject<F: FnMut(&T) -> bool, T: Arbitrary>(g: &mut Gen, f: F) -> T {
use std::iter; use std::iter;
iter::repeat(()) iter::repeat(())
.map(|_| Arbitrary::arbitrary(g)) .map(|_| Arbitrary::arbitrary(g))

View File

@ -1,5 +1,6 @@
//! Matrix iterators. //! Matrix iterators.
use std::iter::FusedIterator;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::mem; use std::mem;
@ -111,6 +112,46 @@ macro_rules! iterator {
} }
} }
impl<'a, N: Scalar, R: Dim, C: Dim, S: 'a + $Storage<N, R, C>> DoubleEndedIterator
for $Name<'a, N, R, C, S>
{
#[inline]
fn next_back(&mut self) -> Option<$Ref> {
unsafe {
if self.size == 0 {
None
} else {
// Pre-decrement `size` such that it now counts to the
// element we want to return.
self.size -= 1;
// Fetch strides
let inner_stride = self.strides.0.value();
let outer_stride = self.strides.1.value();
// Compute number of rows
// Division should be exact
let inner_raw_size = self.inner_end.offset_from(self.inner_ptr) as usize;
let inner_size = inner_raw_size / inner_stride;
// Compute rows and cols remaining
let outer_remaining = self.size / inner_size;
let inner_remaining = self.size % inner_size;
// Compute pointer to last element
let last = self.ptr.offset(
(outer_remaining * outer_stride + inner_remaining * inner_stride)
as isize,
);
// We want either `& *last` or `&mut *last` here, depending
// on the mutability of `$Ref`.
Some(mem::transmute(last))
}
}
}
}
impl<'a, N: Scalar, R: Dim, C: Dim, S: 'a + $Storage<N, R, C>> ExactSizeIterator impl<'a, N: Scalar, R: Dim, C: Dim, S: 'a + $Storage<N, R, C>> ExactSizeIterator
for $Name<'a, N, R, C, S> for $Name<'a, N, R, C, S>
{ {
@ -119,6 +160,11 @@ macro_rules! iterator {
self.size self.size
} }
} }
impl<'a, N: Scalar, R: Dim, C: Dim, S: 'a + $Storage<N, R, C>> FusedIterator
for $Name<'a, N, R, C, S>
{
}
}; };
} }

View File

@ -279,6 +279,22 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> matrixcompare_core::DenseAc
} }
} }
#[cfg(feature = "bytemuck")]
unsafe impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> bytemuck::Zeroable
for Matrix<N, R, C, S>
where
S: bytemuck::Zeroable,
{
}
#[cfg(feature = "bytemuck")]
unsafe impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> bytemuck::Pod for Matrix<N, R, C, S>
where
S: bytemuck::Pod,
Self: Copy,
{
}
impl<N: Scalar, R: Dim, C: Dim, S> Matrix<N, R, C, S> { impl<N: Scalar, R: Dim, C: Dim, S> Matrix<N, R, C, S> {
/// Creates a new matrix with the given data without statically checking that the matrix /// Creates a new matrix with the given data without statically checking that the matrix
/// dimension matches the storage dimension. /// dimension matches the storage dimension.
@ -298,6 +314,21 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
unsafe { Self::from_data_statically_unchecked(data) } unsafe { Self::from_data_statically_unchecked(data) }
} }
/// Creates a new uninitialized matrix with the given uninitialized data
pub unsafe fn from_uninitialized_data(data: mem::MaybeUninit<S>) -> mem::MaybeUninit<Self> {
let res: Matrix<N, R, C, mem::MaybeUninit<S>> = Matrix {
data,
_phantoms: PhantomData,
};
let res: mem::MaybeUninit<Matrix<N, R, C, mem::MaybeUninit<S>>> =
mem::MaybeUninit::new(res);
// safety: since we wrap the inner MaybeUninit in an outer MaybeUninit above, the fact that the `data` field is partially-uninitialized is still opaque.
// with s/transmute_copy/transmute/, rustc claims that `MaybeUninit<Matrix<N, R, C, MaybeUninit<S>>>` may be of a different size from `MaybeUninit<Matrix<N, R, C, S>>`
// but MaybeUninit's documentation says "MaybeUninit<T> is guaranteed to have the same size, alignment, and ABI as T", which implies those types should be the same size
let res: mem::MaybeUninit<Matrix<N, R, C, S>> = mem::transmute_copy(&res);
res
}
/// The shape of this matrix returned as the tuple (number of rows, number of columns). /// The shape of this matrix returned as the tuple (number of rows, number of columns).
/// ///
/// # Examples: /// # Examples:
@ -497,7 +528,7 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
let ncols: SameShapeC<C, C2> = Dim::from_usize(ncols); let ncols: SameShapeC<C, C2> = Dim::from_usize(ncols);
let mut res: MatrixSum<N, R, C, R2, C2> = let mut res: MatrixSum<N, R, C, R2, C2> =
unsafe { Matrix::new_uninitialized_generic(nrows, ncols) }; unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
// TODO: use copy_from // TODO: use copy_from
for j in 0..res.ncols() { for j in 0..res.ncols() {
@ -546,7 +577,7 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
let (nrows, ncols) = self.data.shape(); let (nrows, ncols) = self.data.shape();
unsafe { unsafe {
let mut res = Matrix::new_uninitialized_generic(ncols, nrows); let mut res = crate::unimplemented_or_uninitialized_generic!(ncols, nrows);
self.transpose_to(&mut res); self.transpose_to(&mut res);
res res
@ -564,7 +595,8 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
{ {
let (nrows, ncols) = self.data.shape(); let (nrows, ncols) = self.data.shape();
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) }; let mut res: MatrixMN<N2, R, C> =
unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
for j in 0..ncols.value() { for j in 0..ncols.value() {
for i in 0..nrows.value() { for i in 0..nrows.value() {
@ -608,7 +640,8 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
{ {
let (nrows, ncols) = self.data.shape(); let (nrows, ncols) = self.data.shape();
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) }; let mut res: MatrixMN<N2, R, C> =
unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
for j in 0..ncols.value() { for j in 0..ncols.value() {
for i in 0..nrows.value() { for i in 0..nrows.value() {
@ -635,7 +668,8 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
{ {
let (nrows, ncols) = self.data.shape(); let (nrows, ncols) = self.data.shape();
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) }; let mut res: MatrixMN<N3, R, C> =
unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
assert_eq!( assert_eq!(
(nrows.value(), ncols.value()), (nrows.value(), ncols.value()),
@ -676,7 +710,8 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
{ {
let (nrows, ncols) = self.data.shape(); let (nrows, ncols) = self.data.shape();
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) }; let mut res: MatrixMN<N4, R, C> =
unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
assert_eq!( assert_eq!(
(nrows.value(), ncols.value()), (nrows.value(), ncols.value()),
@ -1170,7 +1205,8 @@ impl<N: SimdComplexField, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S
let (nrows, ncols) = self.data.shape(); let (nrows, ncols) = self.data.shape();
unsafe { unsafe {
let mut res: MatrixMN<_, C, R> = Matrix::new_uninitialized_generic(ncols, nrows); let mut res: MatrixMN<_, C, R> =
crate::unimplemented_or_uninitialized_generic!(ncols, nrows);
self.adjoint_to(&mut res); self.adjoint_to(&mut res);
res res
@ -1311,7 +1347,8 @@ impl<N: Scalar, D: Dim, S: Storage<N, D, D>> SquareMatrix<N, D, S> {
); );
let dim = self.data.shape().0; let dim = self.data.shape().0;
let mut res = unsafe { VectorN::new_uninitialized_generic(dim, U1) }; let mut res: VectorN<N2, D> =
unsafe { crate::unimplemented_or_uninitialized_generic!(dim, U1) };
for i in 0..dim.value() { for i in 0..dim.value() {
unsafe { unsafe {
@ -1438,7 +1475,8 @@ impl<N: Scalar + Zero, D: DimAdd<U1>, S: Storage<N, D>> Vector<N, D, S> {
{ {
let len = self.len(); let len = self.len();
let hnrows = DimSum::<D, U1>::from_usize(len + 1); let hnrows = DimSum::<D, U1>::from_usize(len + 1);
let mut res = unsafe { VectorN::<N, _>::new_uninitialized_generic(hnrows, U1) }; let mut res: VectorN<N, _> =
unsafe { crate::unimplemented_or_uninitialized_generic!(hnrows, U1) };
res.generic_slice_mut((0, 0), self.data.shape()) res.generic_slice_mut((0, 0), self.data.shape())
.copy_from(self); .copy_from(self);
res[(len, 0)] = element; res[(len, 0)] = element;
@ -1783,7 +1821,8 @@ impl<N: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: Storage<N
// TODO: soooo ugly! // TODO: soooo ugly!
let nrows = SameShapeR::<R, R2>::from_usize(3); let nrows = SameShapeR::<R, R2>::from_usize(3);
let ncols = SameShapeC::<C, C2>::from_usize(1); let ncols = SameShapeC::<C, C2>::from_usize(1);
let mut res = Matrix::new_uninitialized_generic(nrows, ncols); let mut res: MatrixCross<N, R, C, R2, C2> =
crate::unimplemented_or_uninitialized_generic!(nrows, ncols);
let ax = self.get_unchecked((0, 0)); let ax = self.get_unchecked((0, 0));
let ay = self.get_unchecked((1, 0)); let ay = self.get_unchecked((1, 0));
@ -1807,7 +1846,8 @@ impl<N: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: Storage<N
// TODO: ugly! // TODO: ugly!
let nrows = SameShapeR::<R, R2>::from_usize(1); let nrows = SameShapeR::<R, R2>::from_usize(1);
let ncols = SameShapeC::<C, C2>::from_usize(3); let ncols = SameShapeC::<C, C2>::from_usize(3);
let mut res = Matrix::new_uninitialized_generic(nrows, ncols); let mut res: MatrixCross<N, R, C, R2, C2> =
crate::unimplemented_or_uninitialized_generic!(nrows, ncols);
let ax = self.get_unchecked((0, 0)); let ax = self.get_unchecked((0, 0));
let ay = self.get_unchecked((0, 1)); let ay = self.get_unchecked((0, 1));

View File

@ -433,8 +433,8 @@ where
"Matrix meet/join error: mismatched dimensions." "Matrix meet/join error: mismatched dimensions."
); );
let mut mres = unsafe { Self::new_uninitialized_generic(shape.0, shape.1) }; let mut mres = unsafe { crate::unimplemented_or_uninitialized_generic!(shape.0, shape.1) };
let mut jres = unsafe { Self::new_uninitialized_generic(shape.0, shape.1) }; let mut jres = unsafe { crate::unimplemented_or_uninitialized_generic!(shape.0, shape.1) };
for i in 0..shape.0.value() * shape.1.value() { for i in 0..shape.0.value() * shape.1.value() {
unsafe { unsafe {

View File

@ -15,6 +15,7 @@ mod alias_slice;
mod array_storage; mod array_storage;
mod cg; mod cg;
mod componentwise; mod componentwise;
#[macro_use]
mod construction; mod construction;
mod construction_slice; mod construction_slice;
mod conversion; mod conversion;

View File

@ -8,7 +8,7 @@ use crate::allocator::Allocator;
use crate::base::{DefaultAllocator, Dim, DimName, Matrix, MatrixMN, Normed, VectorN}; use crate::base::{DefaultAllocator, Dim, DimName, Matrix, MatrixMN, Normed, VectorN};
use crate::constraint::{SameNumberOfColumns, SameNumberOfRows, ShapeConstraint}; use crate::constraint::{SameNumberOfColumns, SameNumberOfRows, ShapeConstraint};
use crate::storage::{Storage, StorageMut}; use crate::storage::{Storage, StorageMut};
use crate::{ComplexField, Scalar, SimdComplexField, Unit}; use crate::{ComplexField, RealField, Scalar, SimdComplexField, Unit};
use simba::scalar::ClosedNeg; use simba::scalar::ClosedNeg;
use simba::simd::{SimdOption, SimdPartialOrd}; use simba::simd::{SimdOption, SimdPartialOrd};
@ -334,11 +334,27 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
{ {
let n = self.norm(); let n = self.norm();
if n >= min_magnitude { if n > min_magnitude {
self.scale_mut(magnitude / n) self.scale_mut(magnitude / n)
} }
} }
/// Returns a new vector with the same magnitude as `self` clamped between `0.0` and `max`.
#[inline]
pub fn cap_magnitude(&self, max: N::RealField) -> MatrixMN<N, R, C>
where
N: RealField,
DefaultAllocator: Allocator<N, R, C>,
{
let n = self.norm();
if n > max {
self.scale(max / n)
} else {
self.clone_owned()
}
}
/// Returns a normalized version of this matrix unless its norm as smaller or equal to `eps`. /// Returns a normalized version of this matrix unless its norm as smaller or equal to `eps`.
/// ///
/// The components of this matrix cannot be SIMD types (see `simd_try_normalize`) instead. /// The components of this matrix cannot be SIMD types (see `simd_try_normalize`) instead.

View File

@ -331,7 +331,7 @@ macro_rules! componentwise_binop_impl(
let (nrows, ncols) = self.shape(); let (nrows, ncols) = self.shape();
let nrows: SameShapeR<R1, R2> = Dim::from_usize(nrows); let nrows: SameShapeR<R1, R2> = Dim::from_usize(nrows);
let ncols: SameShapeC<C1, C2> = Dim::from_usize(ncols); let ncols: SameShapeC<C1, C2> = Dim::from_usize(ncols);
Matrix::new_uninitialized_generic(nrows, ncols) crate::unimplemented_or_uninitialized_generic!(nrows, ncols)
}; };
self.$method_to_statically_unchecked(rhs, &mut res); self.$method_to_statically_unchecked(rhs, &mut res);
@ -573,9 +573,9 @@ where
#[inline] #[inline]
fn mul(self, rhs: &'b Matrix<N, R2, C2, SB>) -> Self::Output { fn mul(self, rhs: &'b Matrix<N, R2, C2, SB>) -> Self::Output {
let mut res = let mut res = unsafe {
unsafe { Matrix::new_uninitialized_generic(self.data.shape().0, rhs.data.shape().1) }; crate::unimplemented_or_uninitialized_generic!(self.data.shape().0, rhs.data.shape().1)
};
self.mul_to(rhs, &mut res); self.mul_to(rhs, &mut res);
res res
} }
@ -684,8 +684,9 @@ where
DefaultAllocator: Allocator<N, C1, C2>, DefaultAllocator: Allocator<N, C1, C2>,
ShapeConstraint: SameNumberOfRows<R1, R2>, ShapeConstraint: SameNumberOfRows<R1, R2>,
{ {
let mut res = let mut res = unsafe {
unsafe { Matrix::new_uninitialized_generic(self.data.shape().1, rhs.data.shape().1) }; crate::unimplemented_or_uninitialized_generic!(self.data.shape().1, rhs.data.shape().1)
};
self.tr_mul_to(rhs, &mut res); self.tr_mul_to(rhs, &mut res);
res res
@ -700,8 +701,9 @@ where
DefaultAllocator: Allocator<N, C1, C2>, DefaultAllocator: Allocator<N, C1, C2>,
ShapeConstraint: SameNumberOfRows<R1, R2>, ShapeConstraint: SameNumberOfRows<R1, R2>,
{ {
let mut res = let mut res = unsafe {
unsafe { Matrix::new_uninitialized_generic(self.data.shape().1, rhs.data.shape().1) }; crate::unimplemented_or_uninitialized_generic!(self.data.shape().1, rhs.data.shape().1)
};
self.ad_mul_to(rhs, &mut res); self.ad_mul_to(rhs, &mut res);
res res
@ -815,8 +817,9 @@ where
let (nrows1, ncols1) = self.data.shape(); let (nrows1, ncols1) = self.data.shape();
let (nrows2, ncols2) = rhs.data.shape(); let (nrows2, ncols2) = rhs.data.shape();
let mut res = let mut res = unsafe {
unsafe { Matrix::new_uninitialized_generic(nrows1.mul(nrows2), ncols1.mul(ncols2)) }; crate::unimplemented_or_uninitialized_generic!(nrows1.mul(nrows2), ncols1.mul(ncols2))
};
{ {
let mut data_res = res.data.ptr_mut(); let mut data_res = res.data.ptr_mut();

View File

@ -17,7 +17,8 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
DefaultAllocator: Allocator<N, U1, C>, DefaultAllocator: Allocator<N, U1, C>,
{ {
let ncols = self.data.shape().1; let ncols = self.data.shape().1;
let mut res = unsafe { RowVectorN::new_uninitialized_generic(U1, ncols) }; let mut res: RowVectorN<N, C> =
unsafe { crate::unimplemented_or_uninitialized_generic!(U1, ncols) };
for i in 0..ncols.value() { for i in 0..ncols.value() {
// TODO: avoid bound checking of column. // TODO: avoid bound checking of column.
@ -42,7 +43,8 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
DefaultAllocator: Allocator<N, C>, DefaultAllocator: Allocator<N, C>,
{ {
let ncols = self.data.shape().1; let ncols = self.data.shape().1;
let mut res = unsafe { VectorN::new_uninitialized_generic(ncols, U1) }; let mut res: VectorN<N, C> =
unsafe { crate::unimplemented_or_uninitialized_generic!(ncols, U1) };
for i in 0..ncols.value() { for i in 0..ncols.value() {
// TODO: avoid bound checking of column. // TODO: avoid bound checking of column.

View File

@ -30,6 +30,12 @@ pub struct Unit<T> {
pub(crate) value: T, pub(crate) value: T,
} }
#[cfg(feature = "bytemuck")]
unsafe impl<T> bytemuck::Zeroable for Unit<T> where T: bytemuck::Zeroable {}
#[cfg(feature = "bytemuck")]
unsafe impl<T> bytemuck::Pod for Unit<T> where T: bytemuck::Pod {}
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
impl<T: Serialize> Serialize for Unit<T> { impl<T: Serialize> Serialize for Unit<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>

View File

@ -48,9 +48,8 @@ where
DefaultAllocator: Allocator<N, D, D>, DefaultAllocator: Allocator<N, D, D>,
Owned<N, D, D>: Clone + Send, Owned<N, D, D>: Clone + Send,
{ {
fn arbitrary<G: Gen>(g: &mut G) -> Self { fn arbitrary(g: &mut Gen) -> Self {
use rand::Rng; let dim = D::try_to_usize().unwrap_or(1 + usize::arbitrary(g) % 50);
let dim = D::try_to_usize().unwrap_or(g.gen_range(1, 50));
Self::new(D::from_usize(dim), || N::arbitrary(g)) Self::new(D::from_usize(dim), || N::arbitrary(g))
} }
} }

View File

@ -51,9 +51,8 @@ where
DefaultAllocator: Allocator<N, D, D>, DefaultAllocator: Allocator<N, D, D>,
Owned<N, D, D>: Clone + Send, Owned<N, D, D>: Clone + Send,
{ {
fn arbitrary<G: Gen>(g: &mut G) -> Self { fn arbitrary(g: &mut Gen) -> Self {
use rand::Rng; let dim = D::try_to_usize().unwrap_or(1 + usize::arbitrary(g) % 50);
let dim = D::try_to_usize().unwrap_or(g.gen_range(1, 50));
Self::new(D::from_usize(dim), || N::arbitrary(g)) Self::new(D::from_usize(dim), || N::arbitrary(g))
} }
} }

View File

@ -1,6 +1,13 @@
use crate::{Quaternion, SimdRealField}; use crate::{
Isometry3, Matrix4, Normed, Point3, Quaternion, Scalar, SimdRealField, Translation3, Unit,
UnitQuaternion, Vector3, VectorN, Zero, U8,
};
use approx::{AbsDiffEq, RelativeEq, UlpsEq};
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::fmt;
use simba::scalar::{ClosedNeg, RealField};
/// A dual quaternion. /// A dual quaternion.
/// ///
@ -28,14 +35,23 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// If a feature that you need is missing, feel free to open an issue or a PR. /// If a feature that you need is missing, feel free to open an issue or a PR.
/// See https://github.com/dimforge/nalgebra/issues/487 /// See https://github.com/dimforge/nalgebra/issues/487
#[repr(C)] #[repr(C)]
#[derive(Debug, Default, Eq, PartialEq, Copy, Clone)] #[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub struct DualQuaternion<N: SimdRealField> { pub struct DualQuaternion<N: Scalar> {
/// The real component of the quaternion /// The real component of the quaternion
pub real: Quaternion<N>, pub real: Quaternion<N>,
/// The dual component of the quaternion /// The dual component of the quaternion
pub dual: Quaternion<N>, pub dual: Quaternion<N>,
} }
impl<N: Scalar + Zero> Default for DualQuaternion<N> {
fn default() -> Self {
Self {
real: Quaternion::default(),
dual: Quaternion::default(),
}
}
}
impl<N: SimdRealField> DualQuaternion<N> impl<N: SimdRealField> DualQuaternion<N>
where where
N::Element: SimdRealField, N::Element: SimdRealField,
@ -77,8 +93,147 @@ where
/// relative_eq!(dq.real.norm(), 1.0); /// relative_eq!(dq.real.norm(), 1.0);
/// ``` /// ```
#[inline] #[inline]
pub fn normalize_mut(&mut self) { pub fn normalize_mut(&mut self) -> N {
*self = self.normalize(); let real_norm = self.real.norm();
self.real /= real_norm;
self.dual /= real_norm;
real_norm
}
/// The conjugate of this dual quaternion, containing the conjugate of
/// the real and imaginary parts..
///
/// # Example
/// ```
/// # use nalgebra::{DualQuaternion, Quaternion};
/// let real = Quaternion::new(1.0, 2.0, 3.0, 4.0);
/// let dual = Quaternion::new(5.0, 6.0, 7.0, 8.0);
/// let dq = DualQuaternion::from_real_and_dual(real, dual);
///
/// let conj = dq.conjugate();
/// assert!(conj.real.i == -2.0 && conj.real.j == -3.0 && conj.real.k == -4.0);
/// assert!(conj.real.w == 1.0);
/// assert!(conj.dual.i == -6.0 && conj.dual.j == -7.0 && conj.dual.k == -8.0);
/// assert!(conj.dual.w == 5.0);
/// ```
#[inline]
#[must_use = "Did you mean to use conjugate_mut()?"]
pub fn conjugate(&self) -> Self {
Self::from_real_and_dual(self.real.conjugate(), self.dual.conjugate())
}
/// Replaces this quaternion by its conjugate.
///
/// # Example
/// ```
/// # use nalgebra::{DualQuaternion, Quaternion};
/// let real = Quaternion::new(1.0, 2.0, 3.0, 4.0);
/// let dual = Quaternion::new(5.0, 6.0, 7.0, 8.0);
/// let mut dq = DualQuaternion::from_real_and_dual(real, dual);
///
/// dq.conjugate_mut();
/// assert!(dq.real.i == -2.0 && dq.real.j == -3.0 && dq.real.k == -4.0);
/// assert!(dq.real.w == 1.0);
/// assert!(dq.dual.i == -6.0 && dq.dual.j == -7.0 && dq.dual.k == -8.0);
/// assert!(dq.dual.w == 5.0);
/// ```
#[inline]
pub fn conjugate_mut(&mut self) {
self.real.conjugate_mut();
self.dual.conjugate_mut();
}
/// Inverts this dual quaternion if it is not zero.
///
/// # Example
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{DualQuaternion, Quaternion};
/// let real = Quaternion::new(1.0, 2.0, 3.0, 4.0);
/// let dual = Quaternion::new(5.0, 6.0, 7.0, 8.0);
/// let dq = DualQuaternion::from_real_and_dual(real, dual);
/// let inverse = dq.try_inverse();
///
/// assert!(inverse.is_some());
/// assert_relative_eq!(inverse.unwrap() * dq, DualQuaternion::identity());
///
/// //Non-invertible case
/// let zero = Quaternion::new(0.0, 0.0, 0.0, 0.0);
/// let dq = DualQuaternion::from_real_and_dual(zero, zero);
/// let inverse = dq.try_inverse();
///
/// assert!(inverse.is_none());
/// ```
#[inline]
#[must_use = "Did you mean to use try_inverse_mut()?"]
pub fn try_inverse(&self) -> Option<Self>
where
N: RealField,
{
let mut res = *self;
if res.try_inverse_mut() {
Some(res)
} else {
None
}
}
/// Inverts this dual quaternion in-place if it is not zero.
///
/// # Example
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{DualQuaternion, Quaternion};
/// let real = Quaternion::new(1.0, 2.0, 3.0, 4.0);
/// let dual = Quaternion::new(5.0, 6.0, 7.0, 8.0);
/// let dq = DualQuaternion::from_real_and_dual(real, dual);
/// let mut dq_inverse = dq;
/// dq_inverse.try_inverse_mut();
///
/// assert_relative_eq!(dq_inverse * dq, DualQuaternion::identity());
///
/// //Non-invertible case
/// let zero = Quaternion::new(0.0, 0.0, 0.0, 0.0);
/// let mut dq = DualQuaternion::from_real_and_dual(zero, zero);
/// assert!(!dq.try_inverse_mut());
/// ```
#[inline]
pub fn try_inverse_mut(&mut self) -> bool
where
N: RealField,
{
let inverted = self.real.try_inverse_mut();
if inverted {
self.dual = -self.real * self.dual * self.real;
true
} else {
false
}
}
/// Linear interpolation between two dual quaternions.
///
/// Computes `self * (1 - t) + other * t`.
///
/// # Example
/// ```
/// # use nalgebra::{DualQuaternion, Quaternion};
/// let dq1 = DualQuaternion::from_real_and_dual(
/// Quaternion::new(1.0, 0.0, 0.0, 4.0),
/// Quaternion::new(0.0, 2.0, 0.0, 0.0)
/// );
/// let dq2 = DualQuaternion::from_real_and_dual(
/// Quaternion::new(2.0, 0.0, 1.0, 0.0),
/// Quaternion::new(0.0, 2.0, 0.0, 0.0)
/// );
/// assert_eq!(dq1.lerp(&dq2, 0.25), DualQuaternion::from_real_and_dual(
/// Quaternion::new(1.25, 0.0, 0.25, 3.0),
/// Quaternion::new(0.0, 2.0, 0.0, 0.0)
/// ));
/// ```
#[inline]
pub fn lerp(&self, other: &Self, t: N) -> Self {
self * (N::one() - t) + other * t
} }
} }
@ -114,3 +269,669 @@ where
}) })
} }
} }
impl<N: RealField> DualQuaternion<N> {
fn to_vector(&self) -> VectorN<N, U8> {
self.as_ref().clone().into()
}
}
impl<N: RealField + AbsDiffEq<Epsilon = N>> AbsDiffEq for DualQuaternion<N> {
type Epsilon = N;
#[inline]
fn default_epsilon() -> Self::Epsilon {
N::default_epsilon()
}
#[inline]
fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
self.to_vector().abs_diff_eq(&other.to_vector(), epsilon) ||
// Account for the double-covering of S², i.e. q = -q
self.to_vector().iter().zip(other.to_vector().iter()).all(|(a, b)| a.abs_diff_eq(&-*b, epsilon))
}
}
impl<N: RealField + RelativeEq<Epsilon = N>> RelativeEq for DualQuaternion<N> {
#[inline]
fn default_max_relative() -> Self::Epsilon {
N::default_max_relative()
}
#[inline]
fn relative_eq(
&self,
other: &Self,
epsilon: Self::Epsilon,
max_relative: Self::Epsilon,
) -> bool {
self.to_vector().relative_eq(&other.to_vector(), epsilon, max_relative) ||
// Account for the double-covering of S², i.e. q = -q
self.to_vector().iter().zip(other.to_vector().iter()).all(|(a, b)| a.relative_eq(&-*b, epsilon, max_relative))
}
}
impl<N: RealField + UlpsEq<Epsilon = N>> UlpsEq for DualQuaternion<N> {
#[inline]
fn default_max_ulps() -> u32 {
N::default_max_ulps()
}
#[inline]
fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
self.to_vector().ulps_eq(&other.to_vector(), epsilon, max_ulps) ||
// Account for the double-covering of S², i.e. q = -q.
self.to_vector().iter().zip(other.to_vector().iter()).all(|(a, b)| a.ulps_eq(&-*b, epsilon, max_ulps))
}
}
/// A unit quaternions. May be used to represent a rotation followed by a translation.
pub type UnitDualQuaternion<N> = Unit<DualQuaternion<N>>;
impl<N: Scalar + ClosedNeg + PartialEq + SimdRealField> PartialEq for UnitDualQuaternion<N> {
#[inline]
fn eq(&self, rhs: &Self) -> bool {
self.as_ref().eq(rhs.as_ref())
}
}
impl<N: Scalar + ClosedNeg + Eq + SimdRealField> Eq for UnitDualQuaternion<N> {}
impl<N: SimdRealField> Normed for DualQuaternion<N> {
type Norm = N::SimdRealField;
#[inline]
fn norm(&self) -> N::SimdRealField {
self.real.norm()
}
#[inline]
fn norm_squared(&self) -> N::SimdRealField {
self.real.norm_squared()
}
#[inline]
fn scale_mut(&mut self, n: Self::Norm) {
self.real.scale_mut(n);
self.dual.scale_mut(n);
}
#[inline]
fn unscale_mut(&mut self, n: Self::Norm) {
self.real.unscale_mut(n);
self.dual.unscale_mut(n);
}
}
impl<N: SimdRealField> UnitDualQuaternion<N>
where
N::Element: SimdRealField,
{
/// The underlying dual quaternion.
///
/// Same as `self.as_ref()`.
///
/// # Example
/// ```
/// # use nalgebra::{DualQuaternion, UnitDualQuaternion, Quaternion};
/// let id = UnitDualQuaternion::identity();
/// assert_eq!(*id.dual_quaternion(), DualQuaternion::from_real_and_dual(
/// Quaternion::new(1.0, 0.0, 0.0, 0.0),
/// Quaternion::new(0.0, 0.0, 0.0, 0.0)
/// ));
/// ```
#[inline]
pub fn dual_quaternion(&self) -> &DualQuaternion<N> {
self.as_ref()
}
/// Compute the conjugate of this unit quaternion.
///
/// # Example
/// ```
/// # use nalgebra::{UnitDualQuaternion, DualQuaternion, Quaternion};
/// let qr = Quaternion::new(1.0, 2.0, 3.0, 4.0);
/// let qd = Quaternion::new(5.0, 6.0, 7.0, 8.0);
/// let unit = UnitDualQuaternion::new_normalize(
/// DualQuaternion::from_real_and_dual(qr, qd)
/// );
/// let conj = unit.conjugate();
/// assert_eq!(conj.real, unit.real.conjugate());
/// assert_eq!(conj.dual, unit.dual.conjugate());
/// ```
#[inline]
#[must_use = "Did you mean to use conjugate_mut()?"]
pub fn conjugate(&self) -> Self {
Self::new_unchecked(self.as_ref().conjugate())
}
/// Compute the conjugate of this unit quaternion in-place.
///
/// # Example
/// ```
/// # use nalgebra::{UnitDualQuaternion, DualQuaternion, Quaternion};
/// let qr = Quaternion::new(1.0, 2.0, 3.0, 4.0);
/// let qd = Quaternion::new(5.0, 6.0, 7.0, 8.0);
/// let unit = UnitDualQuaternion::new_normalize(
/// DualQuaternion::from_real_and_dual(qr, qd)
/// );
/// let mut conj = unit.clone();
/// conj.conjugate_mut();
/// assert_eq!(conj.as_ref().real, unit.as_ref().real.conjugate());
/// assert_eq!(conj.as_ref().dual, unit.as_ref().dual.conjugate());
/// ```
#[inline]
pub fn conjugate_mut(&mut self) {
self.as_mut_unchecked().conjugate_mut()
}
/// Inverts this dual quaternion if it is not zero.
///
/// # Example
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, Quaternion, DualQuaternion};
/// let qr = Quaternion::new(1.0, 2.0, 3.0, 4.0);
/// let qd = Quaternion::new(5.0, 6.0, 7.0, 8.0);
/// let unit = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(qr, qd));
/// let inv = unit.inverse();
/// assert_relative_eq!(unit * inv, UnitDualQuaternion::identity(), epsilon = 1.0e-6);
/// assert_relative_eq!(inv * unit, UnitDualQuaternion::identity(), epsilon = 1.0e-6);
/// ```
#[inline]
#[must_use = "Did you mean to use inverse_mut()?"]
pub fn inverse(&self) -> Self {
let real = Unit::new_unchecked(self.as_ref().real)
.inverse()
.into_inner();
let dual = -real * self.as_ref().dual * real;
UnitDualQuaternion::new_unchecked(DualQuaternion { real, dual })
}
/// Inverts this dual quaternion in place if it is not zero.
///
/// # Example
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, Quaternion, DualQuaternion};
/// let qr = Quaternion::new(1.0, 2.0, 3.0, 4.0);
/// let qd = Quaternion::new(5.0, 6.0, 7.0, 8.0);
/// let unit = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(qr, qd));
/// let mut inv = unit.clone();
/// inv.inverse_mut();
/// assert_relative_eq!(unit * inv, UnitDualQuaternion::identity(), epsilon = 1.0e-6);
/// assert_relative_eq!(inv * unit, UnitDualQuaternion::identity(), epsilon = 1.0e-6);
/// ```
#[inline]
#[must_use = "Did you mean to use inverse_mut()?"]
pub fn inverse_mut(&mut self) {
let quat = self.as_mut_unchecked();
quat.real = Unit::new_unchecked(quat.real).inverse().into_inner();
quat.dual = -quat.real * quat.dual * quat.real;
}
/// The unit dual quaternion needed to make `self` and `other` coincide.
///
/// The result is such that: `self.isometry_to(other) * self == other`.
///
/// # Example
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, DualQuaternion, Quaternion};
/// let qr = Quaternion::new(1.0, 2.0, 3.0, 4.0);
/// let qd = Quaternion::new(5.0, 6.0, 7.0, 8.0);
/// let dq1 = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(qr, qd));
/// let dq2 = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(qd, qr));
/// let dq_to = dq1.isometry_to(&dq2);
/// assert_relative_eq!(dq_to * dq1, dq2, epsilon = 1.0e-6);
/// ```
#[inline]
pub fn isometry_to(&self, other: &Self) -> Self {
other / self
}
/// Linear interpolation between two unit dual quaternions.
///
/// The result is not normalized.
///
/// # Example
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, DualQuaternion, Quaternion};
/// let dq1 = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(
/// Quaternion::new(0.5, 0.0, 0.5, 0.0),
/// Quaternion::new(0.0, 0.5, 0.0, 0.5)
/// ));
/// let dq2 = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(
/// Quaternion::new(0.5, 0.0, 0.0, 0.5),
/// Quaternion::new(0.5, 0.0, 0.5, 0.0)
/// ));
/// assert_relative_eq!(
/// UnitDualQuaternion::new_normalize(dq1.lerp(&dq2, 0.5)),
/// UnitDualQuaternion::new_normalize(
/// DualQuaternion::from_real_and_dual(
/// Quaternion::new(0.5, 0.0, 0.25, 0.25),
/// Quaternion::new(0.25, 0.25, 0.25, 0.25)
/// )
/// ),
/// epsilon = 1.0e-6
/// );
/// ```
#[inline]
pub fn lerp(&self, other: &Self, t: N) -> DualQuaternion<N> {
self.as_ref().lerp(other.as_ref(), t)
}
/// Normalized linear interpolation between two unit quaternions.
///
/// This is the same as `self.lerp` except that the result is normalized.
///
/// # Example
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, DualQuaternion, Quaternion};
/// let dq1 = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(
/// Quaternion::new(0.5, 0.0, 0.5, 0.0),
/// Quaternion::new(0.0, 0.5, 0.0, 0.5)
/// ));
/// let dq2 = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(
/// Quaternion::new(0.5, 0.0, 0.0, 0.5),
/// Quaternion::new(0.5, 0.0, 0.5, 0.0)
/// ));
/// assert_relative_eq!(dq1.nlerp(&dq2, 0.2), UnitDualQuaternion::new_normalize(
/// DualQuaternion::from_real_and_dual(
/// Quaternion::new(0.5, 0.0, 0.4, 0.1),
/// Quaternion::new(0.1, 0.4, 0.1, 0.4)
/// )
/// ), epsilon = 1.0e-6);
/// ```
#[inline]
pub fn nlerp(&self, other: &Self, t: N) -> Self {
let mut res = self.lerp(other, t);
let _ = res.normalize_mut();
Self::new_unchecked(res)
}
/// Screw linear interpolation between two unit quaternions. This creates a
/// smooth arc from one dual-quaternion to another.
///
/// Panics if the angle between both quaternion is 180 degrees (in which case the interpolation
/// is not well-defined). Use `.try_sclerp` instead to avoid the panic.
///
/// # Example
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, DualQuaternion, UnitQuaternion, Vector3};
///
/// let dq1 = UnitDualQuaternion::from_parts(
/// Vector3::new(0.0, 3.0, 0.0).into(),
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_4, 0.0, 0.0),
/// );
///
/// let dq2 = UnitDualQuaternion::from_parts(
/// Vector3::new(0.0, 0.0, 3.0).into(),
/// UnitQuaternion::from_euler_angles(-std::f32::consts::PI, 0.0, 0.0),
/// );
///
/// let dq = dq1.sclerp(&dq2, 1.0 / 3.0);
///
/// assert_relative_eq!(
/// dq.rotation().euler_angles().0, std::f32::consts::FRAC_PI_2, epsilon = 1.0e-6
/// );
/// assert_relative_eq!(dq.translation().vector.y, 3.0, epsilon = 1.0e-6);
#[inline]
pub fn sclerp(&self, other: &Self, t: N) -> Self
where
N: RealField,
{
self.try_sclerp(other, t, N::default_epsilon())
.expect("DualQuaternion sclerp: ambiguous configuration.")
}
/// Computes the screw-linear interpolation between two unit quaternions or returns `None`
/// if both quaternions are approximately 180 degrees apart (in which case the interpolation is
/// not well-defined).
///
/// # Arguments
/// * `self`: the first quaternion to interpolate from.
/// * `other`: the second quaternion to interpolate toward.
/// * `t`: the interpolation parameter. Should be between 0 and 1.
/// * `epsilon`: the value below which the sinus of the angle separating both quaternion
/// must be to return `None`.
#[inline]
pub fn try_sclerp(&self, other: &Self, t: N, epsilon: N) -> Option<Self>
where
N: RealField,
{
let two = N::one() + N::one();
let half = N::one() / two;
// Invert one of the quaternions if we've got a longest-path
// interpolation.
let other = {
let dot_product = self.as_ref().real.coords.dot(&other.as_ref().real.coords);
if dot_product < N::zero() {
-other.clone()
} else {
other.clone()
}
};
let difference = self.as_ref().conjugate() * other.as_ref();
let norm_squared = difference.real.vector().norm_squared();
if relative_eq!(norm_squared, N::zero(), epsilon = epsilon) {
return None;
}
let inverse_norm_squared = N::one() / norm_squared;
let inverse_norm = inverse_norm_squared.sqrt();
let mut angle = two * difference.real.scalar().acos();
let mut pitch = -two * difference.dual.scalar() * inverse_norm;
let direction = difference.real.vector() * inverse_norm;
let moment = (difference.dual.vector()
- direction * (pitch * difference.real.scalar() * half))
* inverse_norm;
angle *= t;
pitch *= t;
let sin = (half * angle).sin();
let cos = (half * angle).cos();
let real = Quaternion::from_parts(cos, direction * sin);
let dual = Quaternion::from_parts(
-pitch * half * sin,
moment * sin + direction * (pitch * half * cos),
);
Some(
self * UnitDualQuaternion::new_unchecked(DualQuaternion::from_real_and_dual(
real, dual,
)),
)
}
/// Return the rotation part of this unit dual quaternion.
///
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3};
/// let dq = UnitDualQuaternion::from_parts(
/// Vector3::new(0.0, 3.0, 0.0).into(),
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_4, 0.0, 0.0)
/// );
///
/// assert_relative_eq!(
/// dq.rotation().angle(), std::f32::consts::FRAC_PI_4, epsilon = 1.0e-6
/// );
/// ```
#[inline]
pub fn rotation(&self) -> UnitQuaternion<N> {
Unit::new_unchecked(self.as_ref().real)
}
/// Return the translation part of this unit dual quaternion.
///
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3};
/// let dq = UnitDualQuaternion::from_parts(
/// Vector3::new(0.0, 3.0, 0.0).into(),
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_4, 0.0, 0.0)
/// );
///
/// assert_relative_eq!(
/// dq.translation().vector, Vector3::new(0.0, 3.0, 0.0), epsilon = 1.0e-6
/// );
/// ```
#[inline]
pub fn translation(&self) -> Translation3<N> {
let two = N::one() + N::one();
Translation3::from(
((self.as_ref().dual * self.as_ref().real.conjugate()) * two)
.vector()
.into_owned(),
)
}
/// Builds an isometry from this unit dual quaternion.
///
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3};
/// let rotation = UnitQuaternion::from_euler_angles(std::f32::consts::PI, 0.0, 0.0);
/// let translation = Vector3::new(1.0, 3.0, 2.5);
/// let dq = UnitDualQuaternion::from_parts(
/// translation.into(),
/// rotation
/// );
/// let iso = dq.to_isometry();
///
/// assert_relative_eq!(iso.rotation.angle(), std::f32::consts::PI, epsilon = 1.0e-6);
/// assert_relative_eq!(iso.translation.vector, translation, epsilon = 1.0e-6);
/// ```
#[inline]
pub fn to_isometry(&self) -> Isometry3<N> {
Isometry3::from_parts(self.translation(), self.rotation())
}
/// Rotate and translate a point by this unit dual quaternion interpreted
/// as an isometry.
///
/// This is the same as the multiplication `self * pt`.
///
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3, Point3};
/// let dq = UnitDualQuaternion::from_parts(
/// Vector3::new(0.0, 3.0, 0.0).into(),
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
/// );
/// let point = Point3::new(1.0, 2.0, 3.0);
///
/// assert_relative_eq!(
/// dq.transform_point(&point), Point3::new(1.0, 0.0, 2.0), epsilon = 1.0e-6
/// );
/// ```
#[inline]
pub fn transform_point(&self, pt: &Point3<N>) -> Point3<N> {
self * pt
}
/// Rotate a vector by this unit dual quaternion, ignoring the translational
/// component.
///
/// This is the same as the multiplication `self * v`.
///
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3};
/// let dq = UnitDualQuaternion::from_parts(
/// Vector3::new(0.0, 3.0, 0.0).into(),
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
/// );
/// let vector = Vector3::new(1.0, 2.0, 3.0);
///
/// assert_relative_eq!(
/// dq.transform_vector(&vector), Vector3::new(1.0, -3.0, 2.0), epsilon = 1.0e-6
/// );
/// ```
#[inline]
pub fn transform_vector(&self, v: &Vector3<N>) -> Vector3<N> {
self * v
}
/// Rotate and translate a point by the inverse of this unit quaternion.
///
/// This may be cheaper than inverting the unit dual quaternion and
/// transforming the point.
///
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3, Point3};
/// let dq = UnitDualQuaternion::from_parts(
/// Vector3::new(0.0, 3.0, 0.0).into(),
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
/// );
/// let point = Point3::new(1.0, 2.0, 3.0);
///
/// assert_relative_eq!(
/// dq.inverse_transform_point(&point), Point3::new(1.0, 3.0, 1.0), epsilon = 1.0e-6
/// );
/// ```
#[inline]
pub fn inverse_transform_point(&self, pt: &Point3<N>) -> Point3<N> {
self.inverse() * pt
}
/// Rotate a vector by the inverse of this unit quaternion, ignoring the
/// translational component.
///
/// This may be cheaper than inverting the unit dual quaternion and
/// transforming the vector.
///
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3};
/// let dq = UnitDualQuaternion::from_parts(
/// Vector3::new(0.0, 3.0, 0.0).into(),
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
/// );
/// let vector = Vector3::new(1.0, 2.0, 3.0);
///
/// assert_relative_eq!(
/// dq.inverse_transform_vector(&vector), Vector3::new(1.0, 3.0, -2.0), epsilon = 1.0e-6
/// );
/// ```
#[inline]
pub fn inverse_transform_vector(&self, v: &Vector3<N>) -> Vector3<N> {
self.inverse() * v
}
/// Rotate a unit vector by the inverse of this unit quaternion, ignoring
/// the translational component. This may be
/// cheaper than inverting the unit dual quaternion and transforming the
/// vector.
///
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Unit, Vector3};
/// let dq = UnitDualQuaternion::from_parts(
/// Vector3::new(0.0, 3.0, 0.0).into(),
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
/// );
/// let vector = Unit::new_unchecked(Vector3::new(0.0, 1.0, 0.0));
///
/// assert_relative_eq!(
/// dq.inverse_transform_unit_vector(&vector),
/// Unit::new_unchecked(Vector3::new(0.0, 0.0, -1.0)),
/// epsilon = 1.0e-6
/// );
/// ```
#[inline]
pub fn inverse_transform_unit_vector(&self, v: &Unit<Vector3<N>>) -> Unit<Vector3<N>> {
self.inverse() * v
}
}
impl<N: SimdRealField + RealField> UnitDualQuaternion<N>
where
N::Element: SimdRealField,
{
/// Converts this unit dual quaternion interpreted as an isometry
/// into its equivalent homogeneous transformation matrix.
///
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{Matrix4, UnitDualQuaternion, UnitQuaternion, Vector3};
/// let dq = UnitDualQuaternion::from_parts(
/// Vector3::new(1.0, 3.0, 2.0).into(),
/// UnitQuaternion::from_axis_angle(&Vector3::z_axis(), std::f32::consts::FRAC_PI_6)
/// );
/// let expected = Matrix4::new(0.8660254, -0.5, 0.0, 1.0,
/// 0.5, 0.8660254, 0.0, 3.0,
/// 0.0, 0.0, 1.0, 2.0,
/// 0.0, 0.0, 0.0, 1.0);
///
/// assert_relative_eq!(dq.to_homogeneous(), expected, epsilon = 1.0e-6);
/// ```
#[inline]
pub fn to_homogeneous(&self) -> Matrix4<N> {
self.to_isometry().to_homogeneous()
}
}
impl<N: RealField> Default for UnitDualQuaternion<N> {
fn default() -> Self {
Self::identity()
}
}
impl<N: RealField + fmt::Display> fmt::Display for UnitDualQuaternion<N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if let Some(axis) = self.rotation().axis() {
let axis = axis.into_inner();
write!(
f,
"UnitDualQuaternion translation: {} angle: {} axis: ({}, {}, {})",
self.translation().vector,
self.rotation().angle(),
axis[0],
axis[1],
axis[2]
)
} else {
write!(
f,
"UnitDualQuaternion translation: {} angle: {} axis: (undefined)",
self.translation().vector,
self.rotation().angle()
)
}
}
}
impl<N: RealField + AbsDiffEq<Epsilon = N>> AbsDiffEq for UnitDualQuaternion<N> {
type Epsilon = N;
#[inline]
fn default_epsilon() -> Self::Epsilon {
N::default_epsilon()
}
#[inline]
fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
self.as_ref().abs_diff_eq(other.as_ref(), epsilon)
}
}
impl<N: RealField + RelativeEq<Epsilon = N>> RelativeEq for UnitDualQuaternion<N> {
#[inline]
fn default_max_relative() -> Self::Epsilon {
N::default_max_relative()
}
#[inline]
fn relative_eq(
&self,
other: &Self,
epsilon: Self::Epsilon,
max_relative: Self::Epsilon,
) -> bool {
self.as_ref()
.relative_eq(other.as_ref(), epsilon, max_relative)
}
}
impl<N: RealField + UlpsEq<Epsilon = N>> UlpsEq for UnitDualQuaternion<N> {
#[inline]
fn default_max_ulps() -> u32 {
N::default_max_ulps()
}
#[inline]
fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
self.as_ref().ulps_eq(other.as_ref(), epsilon, max_ulps)
}
}

View File

@ -0,0 +1,324 @@
use num::Zero;
use alga::general::{
AbstractGroup, AbstractGroupAbelian, AbstractLoop, AbstractMagma, AbstractModule,
AbstractMonoid, AbstractQuasigroup, AbstractSemigroup, Additive, Id, Identity, Module,
Multiplicative, RealField, TwoSidedInverse,
};
use alga::linear::{
AffineTransformation, DirectIsometry, FiniteDimVectorSpace, Isometry, NormedSpace,
ProjectiveTransformation, Similarity, Transformation, VectorSpace,
};
use crate::base::Vector3;
use crate::geometry::{
DualQuaternion, Point3, Quaternion, Translation3, UnitDualQuaternion, UnitQuaternion,
};
impl<N: RealField + simba::scalar::RealField> Identity<Multiplicative> for DualQuaternion<N> {
#[inline]
fn identity() -> Self {
Self::identity()
}
}
impl<N: RealField + simba::scalar::RealField> Identity<Additive> for DualQuaternion<N> {
#[inline]
fn identity() -> Self {
Self::zero()
}
}
impl<N: RealField + simba::scalar::RealField> AbstractMagma<Multiplicative> for DualQuaternion<N> {
#[inline]
fn operate(&self, rhs: &Self) -> Self {
self * rhs
}
}
impl<N: RealField + simba::scalar::RealField> AbstractMagma<Additive> for DualQuaternion<N> {
#[inline]
fn operate(&self, rhs: &Self) -> Self {
self + rhs
}
}
impl<N: RealField + simba::scalar::RealField> TwoSidedInverse<Additive> for DualQuaternion<N> {
#[inline]
fn two_sided_inverse(&self) -> Self {
-self
}
}
macro_rules! impl_structures(
($DualQuaternion: ident; $($marker: ident<$operator: ident>),* $(,)*) => {$(
impl<N: RealField + simba::scalar::RealField> $marker<$operator> for $DualQuaternion<N> { }
)*}
);
impl_structures!(
DualQuaternion;
AbstractSemigroup<Multiplicative>,
AbstractMonoid<Multiplicative>,
AbstractSemigroup<Additive>,
AbstractQuasigroup<Additive>,
AbstractMonoid<Additive>,
AbstractLoop<Additive>,
AbstractGroup<Additive>,
AbstractGroupAbelian<Additive>
);
/*
*
* Vector space.
*
*/
impl<N: RealField + simba::scalar::RealField> AbstractModule for DualQuaternion<N> {
type AbstractRing = N;
#[inline]
fn multiply_by(&self, n: N) -> Self {
self * n
}
}
impl<N: RealField + simba::scalar::RealField> Module for DualQuaternion<N> {
type Ring = N;
}
impl<N: RealField + simba::scalar::RealField> VectorSpace for DualQuaternion<N> {
type Field = N;
}
impl<N: RealField + simba::scalar::RealField> FiniteDimVectorSpace for DualQuaternion<N> {
#[inline]
fn dimension() -> usize {
8
}
#[inline]
fn canonical_basis_element(i: usize) -> Self {
if i < 4 {
DualQuaternion::from_real_and_dual(
Quaternion::canonical_basis_element(i),
Quaternion::zero(),
)
} else {
DualQuaternion::from_real_and_dual(
Quaternion::zero(),
Quaternion::canonical_basis_element(i - 4),
)
}
}
#[inline]
fn dot(&self, other: &Self) -> N {
self.real.dot(&other.real) + self.dual.dot(&other.dual)
}
#[inline]
unsafe fn component_unchecked(&self, i: usize) -> &N {
self.as_ref().get_unchecked(i)
}
#[inline]
unsafe fn component_unchecked_mut(&mut self, i: usize) -> &mut N {
self.as_mut().get_unchecked_mut(i)
}
}
impl<N: RealField + simba::scalar::RealField> NormedSpace for DualQuaternion<N> {
type RealField = N;
type ComplexField = N;
#[inline]
fn norm_squared(&self) -> N {
self.real.norm_squared()
}
#[inline]
fn norm(&self) -> N {
self.real.norm()
}
#[inline]
fn normalize(&self) -> Self {
self.normalize()
}
#[inline]
fn normalize_mut(&mut self) -> N {
self.normalize_mut()
}
#[inline]
fn try_normalize(&self, min_norm: N) -> Option<Self> {
let real_norm = self.real.norm();
if real_norm > min_norm {
Some(Self::from_real_and_dual(
self.real / real_norm,
self.dual / real_norm,
))
} else {
None
}
}
#[inline]
fn try_normalize_mut(&mut self, min_norm: N) -> Option<N> {
let real_norm = self.real.norm();
if real_norm > min_norm {
self.real /= real_norm;
self.dual /= real_norm;
Some(real_norm)
} else {
None
}
}
}
/*
*
* Implementations for UnitDualQuaternion.
*
*/
impl<N: RealField + simba::scalar::RealField> Identity<Multiplicative> for UnitDualQuaternion<N> {
#[inline]
fn identity() -> Self {
Self::identity()
}
}
impl<N: RealField + simba::scalar::RealField> AbstractMagma<Multiplicative>
for UnitDualQuaternion<N>
{
#[inline]
fn operate(&self, rhs: &Self) -> Self {
self * rhs
}
}
impl<N: RealField + simba::scalar::RealField> TwoSidedInverse<Multiplicative>
for UnitDualQuaternion<N>
{
#[inline]
fn two_sided_inverse(&self) -> Self {
self.inverse()
}
#[inline]
fn two_sided_inverse_mut(&mut self) {
self.inverse_mut()
}
}
impl_structures!(
UnitDualQuaternion;
AbstractSemigroup<Multiplicative>,
AbstractQuasigroup<Multiplicative>,
AbstractMonoid<Multiplicative>,
AbstractLoop<Multiplicative>,
AbstractGroup<Multiplicative>
);
impl<N: RealField + simba::scalar::RealField> Transformation<Point3<N>> for UnitDualQuaternion<N> {
#[inline]
fn transform_point(&self, pt: &Point3<N>) -> Point3<N> {
self.transform_point(pt)
}
#[inline]
fn transform_vector(&self, v: &Vector3<N>) -> Vector3<N> {
self.transform_vector(v)
}
}
impl<N: RealField + simba::scalar::RealField> ProjectiveTransformation<Point3<N>>
for UnitDualQuaternion<N>
{
#[inline]
fn inverse_transform_point(&self, pt: &Point3<N>) -> Point3<N> {
self.inverse_transform_point(pt)
}
#[inline]
fn inverse_transform_vector(&self, v: &Vector3<N>) -> Vector3<N> {
self.inverse_transform_vector(v)
}
}
impl<N: RealField + simba::scalar::RealField> AffineTransformation<Point3<N>>
for UnitDualQuaternion<N>
{
type Rotation = UnitQuaternion<N>;
type NonUniformScaling = Id;
type Translation = Translation3<N>;
#[inline]
fn decompose(&self) -> (Self::Translation, Self::Rotation, Id, Self::Rotation) {
(
self.translation(),
self.rotation(),
Id::new(),
UnitQuaternion::identity(),
)
}
#[inline]
fn append_translation(&self, translation: &Self::Translation) -> Self {
self * Self::from_parts(translation.clone(), UnitQuaternion::identity())
}
#[inline]
fn prepend_translation(&self, translation: &Self::Translation) -> Self {
Self::from_parts(translation.clone(), UnitQuaternion::identity()) * self
}
#[inline]
fn append_rotation(&self, r: &Self::Rotation) -> Self {
r * self
}
#[inline]
fn prepend_rotation(&self, r: &Self::Rotation) -> Self {
self * r
}
#[inline]
fn append_scaling(&self, _: &Self::NonUniformScaling) -> Self {
self.clone()
}
#[inline]
fn prepend_scaling(&self, _: &Self::NonUniformScaling) -> Self {
self.clone()
}
}
impl<N: RealField + simba::scalar::RealField> Similarity<Point3<N>> for UnitDualQuaternion<N> {
type Scaling = Id;
#[inline]
fn translation(&self) -> Translation3<N> {
self.translation()
}
#[inline]
fn rotation(&self) -> UnitQuaternion<N> {
self.rotation()
}
#[inline]
fn scaling(&self) -> Id {
Id::new()
}
}
macro_rules! marker_impl(
($($Trait: ident),*) => {$(
impl<N: RealField + simba::scalar::RealField> $Trait<Point3<N>> for UnitDualQuaternion<N> { }
)*}
);
marker_impl!(Isometry, DirectIsometry);

View File

@ -1,6 +1,12 @@
use crate::{DualQuaternion, Quaternion, SimdRealField}; use crate::{
DualQuaternion, Isometry3, Quaternion, Scalar, SimdRealField, Translation3, UnitDualQuaternion,
UnitQuaternion,
};
use num::{One, Zero};
#[cfg(feature = "arbitrary")]
use quickcheck::{Arbitrary, Gen};
impl<N: SimdRealField> DualQuaternion<N> { impl<N: Scalar> DualQuaternion<N> {
/// Creates a dual quaternion from its rotation and translation components. /// Creates a dual quaternion from its rotation and translation components.
/// ///
/// # Example /// # Example
@ -16,7 +22,8 @@ impl<N: SimdRealField> DualQuaternion<N> {
pub fn from_real_and_dual(real: Quaternion<N>, dual: Quaternion<N>) -> Self { pub fn from_real_and_dual(real: Quaternion<N>, dual: Quaternion<N>) -> Self {
Self { real, dual } Self { real, dual }
} }
/// The dual quaternion multiplicative identity
/// The dual quaternion multiplicative identity.
/// ///
/// # Example /// # Example
/// ///
@ -33,10 +40,183 @@ impl<N: SimdRealField> DualQuaternion<N> {
/// assert_eq!(dq2 * dq1, dq2); /// assert_eq!(dq2 * dq1, dq2);
/// ``` /// ```
#[inline] #[inline]
pub fn identity() -> Self { pub fn identity() -> Self
where
N: SimdRealField,
{
Self::from_real_and_dual( Self::from_real_and_dual(
Quaternion::from_real(N::one()), Quaternion::from_real(N::one()),
Quaternion::from_real(N::zero()), Quaternion::from_real(N::zero()),
) )
} }
} }
impl<N: SimdRealField> DualQuaternion<N>
where
N::Element: SimdRealField,
{
/// Creates a dual quaternion from only its real part, with no translation
/// component.
///
/// # Example
/// ```
/// # use nalgebra::{DualQuaternion, Quaternion};
/// let rot = Quaternion::new(1.0, 2.0, 3.0, 4.0);
///
/// let dq = DualQuaternion::from_real(rot);
/// assert_eq!(dq.real.w, 1.0);
/// assert_eq!(dq.dual.w, 0.0);
/// ```
#[inline]
pub fn from_real(real: Quaternion<N>) -> Self {
Self {
real,
dual: Quaternion::zero(),
}
}
}
impl<N: SimdRealField> One for DualQuaternion<N>
where
N::Element: SimdRealField,
{
#[inline]
fn one() -> Self {
Self::identity()
}
}
impl<N: SimdRealField> Zero for DualQuaternion<N>
where
N::Element: SimdRealField,
{
#[inline]
fn zero() -> Self {
DualQuaternion::from_real_and_dual(Quaternion::zero(), Quaternion::zero())
}
#[inline]
fn is_zero(&self) -> bool {
self.real.is_zero() && self.dual.is_zero()
}
}
#[cfg(feature = "arbitrary")]
impl<N> Arbitrary for DualQuaternion<N>
where
N: SimdRealField + Arbitrary + Send,
N::Element: SimdRealField,
{
#[inline]
fn arbitrary(rng: &mut Gen) -> Self {
Self::from_real_and_dual(Arbitrary::arbitrary(rng), Arbitrary::arbitrary(rng))
}
}
impl<N: SimdRealField> UnitDualQuaternion<N> {
/// The unit dual quaternion multiplicative identity, which also represents
/// the identity transformation as an isometry.
///
/// ```
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3, Point3};
/// let ident = UnitDualQuaternion::identity();
/// let point = Point3::new(1.0, -4.3, 3.33);
///
/// assert_eq!(ident * point, point);
/// assert_eq!(ident, ident.inverse());
/// ```
#[inline]
pub fn identity() -> Self {
Self::new_unchecked(DualQuaternion::identity())
}
}
impl<N: SimdRealField> UnitDualQuaternion<N>
where
N::Element: SimdRealField,
{
/// Return a dual quaternion representing the translation and orientation
/// given by the provided rotation quaternion and translation vector.
///
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3, Point3};
/// let dq = UnitDualQuaternion::from_parts(
/// Vector3::new(0.0, 3.0, 0.0).into(),
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
/// );
/// let point = Point3::new(1.0, 2.0, 3.0);
///
/// assert_relative_eq!(dq * point, Point3::new(1.0, 0.0, 2.0), epsilon = 1.0e-6);
/// ```
#[inline]
pub fn from_parts(translation: Translation3<N>, rotation: UnitQuaternion<N>) -> Self {
let half: N = crate::convert(0.5f64);
UnitDualQuaternion::new_unchecked(DualQuaternion {
real: rotation.clone().into_inner(),
dual: Quaternion::from_parts(N::zero(), translation.vector)
* rotation.clone().into_inner()
* half,
})
}
/// Return a unit dual quaternion representing the translation and orientation
/// given by the provided isometry.
///
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{Isometry3, UnitDualQuaternion, UnitQuaternion, Vector3, Point3};
/// let iso = Isometry3::from_parts(
/// Vector3::new(0.0, 3.0, 0.0).into(),
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
/// );
/// let dq = UnitDualQuaternion::from_isometry(&iso);
/// let point = Point3::new(1.0, 2.0, 3.0);
///
/// assert_relative_eq!(dq * point, iso * point, epsilon = 1.0e-6);
/// ```
#[inline]
pub fn from_isometry(isometry: &Isometry3<N>) -> Self {
UnitDualQuaternion::from_parts(isometry.translation, isometry.rotation)
}
/// Creates a dual quaternion from a unit quaternion rotation.
///
/// # Example
/// ```
/// # #[macro_use] extern crate approx;
/// # use nalgebra::{UnitQuaternion, UnitDualQuaternion, Quaternion};
/// let q = Quaternion::new(1.0, 2.0, 3.0, 4.0);
/// let rot = UnitQuaternion::new_normalize(q);
///
/// let dq = UnitDualQuaternion::from_rotation(rot);
/// assert_relative_eq!(dq.as_ref().real.norm(), 1.0, epsilon = 1.0e-6);
/// assert_eq!(dq.as_ref().dual.norm(), 0.0);
/// ```
#[inline]
pub fn from_rotation(rotation: UnitQuaternion<N>) -> Self {
Self::new_unchecked(DualQuaternion::from_real(rotation.into_inner()))
}
}
impl<N: SimdRealField> One for UnitDualQuaternion<N>
where
N::Element: SimdRealField,
{
#[inline]
fn one() -> Self {
Self::identity()
}
}
#[cfg(feature = "arbitrary")]
impl<N> Arbitrary for UnitDualQuaternion<N>
where
N: SimdRealField + Arbitrary + Send,
N::Element: SimdRealField,
{
#[inline]
fn arbitrary(rng: &mut Gen) -> Self {
Self::new_normalize(Arbitrary::arbitrary(rng))
}
}

View File

@ -0,0 +1,188 @@
use simba::scalar::{RealField, SubsetOf, SupersetOf};
use simba::simd::SimdRealField;
use crate::base::dimension::U3;
use crate::base::{Matrix4, Vector4};
use crate::geometry::{
DualQuaternion, Isometry3, Similarity3, SuperTCategoryOf, TAffine, Transform, Translation3,
UnitDualQuaternion, UnitQuaternion,
};
/*
* This file provides the following conversions:
* =============================================
*
* DualQuaternion -> DualQuaternion
* UnitDualQuaternion -> UnitDualQuaternion
* UnitDualQuaternion -> Isometry<U3>
* UnitDualQuaternion -> Similarity<U3>
* UnitDualQuaternion -> Transform<U3>
* UnitDualQuaternion -> Matrix<U4> (homogeneous)
*
* NOTE:
* UnitDualQuaternion -> DualQuaternion is already provided by: Unit<T> -> T
*/
impl<N1, N2> SubsetOf<DualQuaternion<N2>> for DualQuaternion<N1>
where
N1: SimdRealField,
N2: SimdRealField + SupersetOf<N1>,
{
#[inline]
fn to_superset(&self) -> DualQuaternion<N2> {
DualQuaternion::from_real_and_dual(self.real.to_superset(), self.dual.to_superset())
}
#[inline]
fn is_in_subset(dq: &DualQuaternion<N2>) -> bool {
crate::is_convertible::<_, Vector4<N1>>(&dq.real.coords)
&& crate::is_convertible::<_, Vector4<N1>>(&dq.dual.coords)
}
#[inline]
fn from_superset_unchecked(dq: &DualQuaternion<N2>) -> Self {
DualQuaternion::from_real_and_dual(
dq.real.to_subset_unchecked(),
dq.dual.to_subset_unchecked(),
)
}
}
impl<N1, N2> SubsetOf<UnitDualQuaternion<N2>> for UnitDualQuaternion<N1>
where
N1: SimdRealField,
N2: SimdRealField + SupersetOf<N1>,
{
#[inline]
fn to_superset(&self) -> UnitDualQuaternion<N2> {
UnitDualQuaternion::new_unchecked(self.as_ref().to_superset())
}
#[inline]
fn is_in_subset(dq: &UnitDualQuaternion<N2>) -> bool {
crate::is_convertible::<_, DualQuaternion<N1>>(dq.as_ref())
}
#[inline]
fn from_superset_unchecked(dq: &UnitDualQuaternion<N2>) -> Self {
Self::new_unchecked(crate::convert_ref_unchecked(dq.as_ref()))
}
}
impl<N1, N2> SubsetOf<Isometry3<N2>> for UnitDualQuaternion<N1>
where
N1: RealField,
N2: RealField + SupersetOf<N1>,
{
#[inline]
fn to_superset(&self) -> Isometry3<N2> {
let dq: UnitDualQuaternion<N2> = self.to_superset();
let iso = dq.to_isometry();
crate::convert_unchecked(iso)
}
#[inline]
fn is_in_subset(iso: &Isometry3<N2>) -> bool {
crate::is_convertible::<_, UnitQuaternion<N1>>(&iso.rotation)
&& crate::is_convertible::<_, Translation3<N1>>(&iso.translation)
}
#[inline]
fn from_superset_unchecked(iso: &Isometry3<N2>) -> Self {
let dq = UnitDualQuaternion::<N2>::from_isometry(iso);
crate::convert_unchecked(dq)
}
}
impl<N1, N2> SubsetOf<Similarity3<N2>> for UnitDualQuaternion<N1>
where
N1: RealField,
N2: RealField + SupersetOf<N1>,
{
#[inline]
fn to_superset(&self) -> Similarity3<N2> {
Similarity3::from_isometry(crate::convert_ref(self), N2::one())
}
#[inline]
fn is_in_subset(sim: &Similarity3<N2>) -> bool {
sim.scaling() == N2::one()
}
#[inline]
fn from_superset_unchecked(sim: &Similarity3<N2>) -> Self {
crate::convert_ref_unchecked(&sim.isometry)
}
}
impl<N1, N2, C> SubsetOf<Transform<N2, U3, C>> for UnitDualQuaternion<N1>
where
N1: RealField,
N2: RealField + SupersetOf<N1>,
C: SuperTCategoryOf<TAffine>,
{
#[inline]
fn to_superset(&self) -> Transform<N2, U3, C> {
Transform::from_matrix_unchecked(self.to_homogeneous().to_superset())
}
#[inline]
fn is_in_subset(t: &Transform<N2, U3, C>) -> bool {
<Self as SubsetOf<_>>::is_in_subset(t.matrix())
}
#[inline]
fn from_superset_unchecked(t: &Transform<N2, U3, C>) -> Self {
Self::from_superset_unchecked(t.matrix())
}
}
impl<N1: RealField, N2: RealField + SupersetOf<N1>> SubsetOf<Matrix4<N2>>
for UnitDualQuaternion<N1>
{
#[inline]
fn to_superset(&self) -> Matrix4<N2> {
self.to_homogeneous().to_superset()
}
#[inline]
fn is_in_subset(m: &Matrix4<N2>) -> bool {
crate::is_convertible::<_, Isometry3<N1>>(m)
}
#[inline]
fn from_superset_unchecked(m: &Matrix4<N2>) -> Self {
let iso: Isometry3<N1> = crate::convert_ref_unchecked(m);
Self::from_isometry(&iso)
}
}
impl<N: SimdRealField + RealField> From<UnitDualQuaternion<N>> for Matrix4<N>
where
N::Element: SimdRealField,
{
#[inline]
fn from(dq: UnitDualQuaternion<N>) -> Self {
dq.to_homogeneous()
}
}
impl<N: SimdRealField> From<UnitDualQuaternion<N>> for Isometry3<N>
where
N::Element: SimdRealField,
{
#[inline]
fn from(dq: UnitDualQuaternion<N>) -> Self {
dq.to_isometry()
}
}
impl<N: SimdRealField> From<Isometry3<N>> for UnitDualQuaternion<N>
where
N::Element: SimdRealField,
{
#[inline]
fn from(iso: Isometry3<N>) -> Self {
Self::from_isometry(&iso)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -102,7 +102,7 @@ where
DefaultAllocator: Allocator<N, D>, DefaultAllocator: Allocator<N, D>,
{ {
#[inline] #[inline]
fn arbitrary<G: Gen>(rng: &mut G) -> Self { fn arbitrary(rng: &mut Gen) -> Self {
Self::from_parts(Arbitrary::arbitrary(rng), Arbitrary::arbitrary(rng)) Self::from_parts(Arbitrary::arbitrary(rng), Arbitrary::arbitrary(rng))
} }
} }

View File

@ -6,7 +6,8 @@ use crate::base::dimension::{DimMin, DimName, DimNameAdd, DimNameSum, U1};
use crate::base::{DefaultAllocator, MatrixN, Scalar}; use crate::base::{DefaultAllocator, MatrixN, Scalar};
use crate::geometry::{ use crate::geometry::{
AbstractRotation, Isometry, Similarity, SuperTCategoryOf, TAffine, Transform, Translation, AbstractRotation, Isometry, Isometry3, Similarity, SuperTCategoryOf, TAffine, Transform,
Translation, UnitDualQuaternion, UnitQuaternion,
}; };
/* /*
@ -14,6 +15,7 @@ use crate::geometry::{
* ============================================= * =============================================
* *
* Isometry -> Isometry * Isometry -> Isometry
* Isometry3 -> UnitDualQuaternion
* Isometry -> Similarity * Isometry -> Similarity
* Isometry -> Transform * Isometry -> Transform
* Isometry -> Matrix (homogeneous) * Isometry -> Matrix (homogeneous)
@ -47,6 +49,30 @@ where
} }
} }
impl<N1, N2> SubsetOf<UnitDualQuaternion<N2>> for Isometry3<N1>
where
N1: RealField,
N2: RealField + SupersetOf<N1>,
{
#[inline]
fn to_superset(&self) -> UnitDualQuaternion<N2> {
let dq = UnitDualQuaternion::<N1>::from_isometry(self);
dq.to_superset()
}
#[inline]
fn is_in_subset(dq: &UnitDualQuaternion<N2>) -> bool {
crate::is_convertible::<_, UnitQuaternion<N1>>(&dq.rotation())
&& crate::is_convertible::<_, Translation<N1, _>>(&dq.translation())
}
#[inline]
fn from_superset_unchecked(dq: &UnitDualQuaternion<N2>) -> Self {
let dq: UnitDualQuaternion<N1> = crate::convert_ref_unchecked(dq);
dq.to_isometry()
}
}
impl<N1, N2, D: DimName, R1, R2> SubsetOf<Similarity<N2, D, R2>> for Isometry<N1, D, R1> impl<N1, N2, D: DimName, R1, R2> SubsetOf<Similarity<N2, D, R2>> for Isometry<N1, D, R1>
where where
N1: RealField, N1: RealField,

View File

@ -36,7 +36,10 @@ mod quaternion_ops;
mod quaternion_simba; mod quaternion_simba;
mod dual_quaternion; mod dual_quaternion;
#[cfg(feature = "alga")]
mod dual_quaternion_alga;
mod dual_quaternion_construction; mod dual_quaternion_construction;
mod dual_quaternion_conversion;
mod dual_quaternion_ops; mod dual_quaternion_ops;
mod unit_complex; mod unit_complex;

View File

@ -705,7 +705,7 @@ impl<N: RealField + Arbitrary> Arbitrary for Orthographic3<N>
where where
Matrix4<N>: Send, Matrix4<N>: Send,
{ {
fn arbitrary<G: Gen>(g: &mut G) -> Self { fn arbitrary(g: &mut Gen) -> Self {
let left = Arbitrary::arbitrary(g); let left = Arbitrary::arbitrary(g);
let right = helper::reject(g, |x: &N| *x > left); let right = helper::reject(g, |x: &N| *x > left);
let bottom = Arbitrary::arbitrary(g); let bottom = Arbitrary::arbitrary(g);

View File

@ -283,7 +283,7 @@ where
#[cfg(feature = "arbitrary")] #[cfg(feature = "arbitrary")]
impl<N: RealField + Arbitrary> Arbitrary for Perspective3<N> { impl<N: RealField + Arbitrary> Arbitrary for Perspective3<N> {
fn arbitrary<G: Gen>(g: &mut G) -> Self { fn arbitrary(g: &mut Gen) -> Self {
let znear = Arbitrary::arbitrary(g); let znear = Arbitrary::arbitrary(g);
let zfar = helper::reject(g, |&x: &N| !(x - znear).is_zero()); let zfar = helper::reject(g, |&x: &N| !(x - znear).is_zero());
let aspect = helper::reject(g, |&x: &N| !x.is_zero()); let aspect = helper::reject(g, |&x: &N| !x.is_zero());

View File

@ -65,6 +65,24 @@ where
{ {
} }
#[cfg(feature = "bytemuck")]
unsafe impl<N: Scalar, D: DimName> bytemuck::Zeroable for Point<N, D>
where
VectorN<N, D>: bytemuck::Zeroable,
DefaultAllocator: Allocator<N, D>,
{
}
#[cfg(feature = "bytemuck")]
unsafe impl<N: Scalar, D: DimName> bytemuck::Pod for Point<N, D>
where
N: Copy,
VectorN<N, D>: bytemuck::Pod,
DefaultAllocator: Allocator<N, D>,
<DefaultAllocator as Allocator<N, D>>::Buffer: Copy,
{
}
#[cfg(feature = "serde-serialize")] #[cfg(feature = "serde-serialize")]
impl<N: Scalar, D: DimName> Serialize for Point<N, D> impl<N: Scalar, D: DimName> Serialize for Point<N, D>
where where
@ -181,7 +199,12 @@ where
D: DimNameAdd<U1>, D: DimNameAdd<U1>,
DefaultAllocator: Allocator<N, DimNameSum<D, U1>>, DefaultAllocator: Allocator<N, DimNameSum<D, U1>>,
{ {
let mut res = unsafe { VectorN::<_, DimNameSum<D, U1>>::new_uninitialized() }; let mut res = unsafe {
crate::unimplemented_or_uninitialized_generic!(
<DimNameSum<D, U1> as DimName>::name(),
U1
)
};
res.fixed_slice_mut::<D, U1>(0, 0).copy_from(&self.coords); res.fixed_slice_mut::<D, U1>(0, 0).copy_from(&self.coords);
res[(D::dim(), 0)] = N::one(); res[(D::dim(), 0)] = N::one();

View File

@ -24,7 +24,10 @@ where
/// Creates a new point with uninitialized coordinates. /// Creates a new point with uninitialized coordinates.
#[inline] #[inline]
pub unsafe fn new_uninitialized() -> Self { pub unsafe fn new_uninitialized() -> Self {
Self::from(VectorN::new_uninitialized()) Self::from(crate::unimplemented_or_uninitialized_generic!(
D::name(),
U1
))
} }
/// Creates a new point with all coordinates equal to zero. /// Creates a new point with all coordinates equal to zero.
@ -153,7 +156,7 @@ where
<DefaultAllocator as Allocator<N, D>>::Buffer: Send, <DefaultAllocator as Allocator<N, D>>::Buffer: Send,
{ {
#[inline] #[inline]
fn arbitrary<G: Gen>(g: &mut G) -> Self { fn arbitrary(g: &mut Gen) -> Self {
Self::from(VectorN::arbitrary(g)) Self::from(VectorN::arbitrary(g))
} }
} }

View File

@ -40,6 +40,17 @@ impl<N: Scalar + Zero> Default for Quaternion<N> {
} }
} }
#[cfg(feature = "bytemuck")]
unsafe impl<N: Scalar> bytemuck::Zeroable for Quaternion<N> where Vector4<N>: bytemuck::Zeroable {}
#[cfg(feature = "bytemuck")]
unsafe impl<N: Scalar> bytemuck::Pod for Quaternion<N>
where
Vector4<N>: bytemuck::Pod,
N: Copy,
{
}
#[cfg(feature = "abomonation-serialize")] #[cfg(feature = "abomonation-serialize")]
impl<N: Scalar> Abomonation for Quaternion<N> impl<N: Scalar> Abomonation for Quaternion<N>
where where
@ -1542,6 +1553,17 @@ where
pub fn inverse_transform_unit_vector(&self, v: &Unit<Vector3<N>>) -> Unit<Vector3<N>> { pub fn inverse_transform_unit_vector(&self, v: &Unit<Vector3<N>>) -> Unit<Vector3<N>> {
self.inverse() * v self.inverse() * v
} }
/// Appends to `self` a rotation given in the axis-angle form, using a linearized formulation.
///
/// This is faster, but approximate, way to compute `UnitQuaternion::new(axisangle) * self`.
#[inline]
pub fn append_axisangle_linearized(&self, axisangle: &Vector3<N>) -> Self {
let half: N = crate::convert(0.5);
let q1 = self.into_inner();
let q2 = Quaternion::from_imag(axisangle * half);
Unit::new_normalize(q1 + q2 * q1)
}
} }
impl<N: RealField> Default for UnitQuaternion<N> { impl<N: RealField> Default for UnitQuaternion<N> {

View File

@ -160,7 +160,7 @@ where
Owned<N, U4>: Send, Owned<N, U4>: Send,
{ {
#[inline] #[inline]
fn arbitrary<G: Gen>(g: &mut G) -> Self { fn arbitrary(g: &mut Gen) -> Self {
Self::new( Self::new(
N::arbitrary(g), N::arbitrary(g),
N::arbitrary(g), N::arbitrary(g),
@ -266,6 +266,17 @@ where
Self::new_unchecked(q) Self::new_unchecked(q)
} }
/// Builds an unit quaternion from a basis assumed to be orthonormal.
///
/// In order to get a valid unit-quaternion, the input must be an
/// orthonormal basis, i.e., all vectors are normalized, and the are
/// all orthogonal to each other. These invariants are not checked
/// by this method.
pub fn from_basis_unchecked(basis: &[Vector3<N>; 3]) -> Self {
let rot = Rotation3::from_basis_unchecked(basis);
Self::from_rotation_matrix(&rot)
}
/// Builds an unit quaternion from a rotation matrix. /// Builds an unit quaternion from a rotation matrix.
/// ///
/// # Example /// # Example
@ -834,7 +845,7 @@ where
Owned<N, U3>: Send, Owned<N, U3>: Send,
{ {
#[inline] #[inline]
fn arbitrary<G: Gen>(g: &mut G) -> Self { fn arbitrary(g: &mut Gen) -> Self {
let axisangle = Vector3::arbitrary(g); let axisangle = Vector3::arbitrary(g);
Self::from_scaled_axis(axisangle) Self::from_scaled_axis(axisangle)
} }

View File

@ -10,7 +10,7 @@ use crate::base::dimension::U3;
use crate::base::{Matrix3, Matrix4, Scalar, Vector4}; use crate::base::{Matrix3, Matrix4, Scalar, Vector4};
use crate::geometry::{ use crate::geometry::{
AbstractRotation, Isometry, Quaternion, Rotation, Rotation3, Similarity, SuperTCategoryOf, AbstractRotation, Isometry, Quaternion, Rotation, Rotation3, Similarity, SuperTCategoryOf,
TAffine, Transform, Translation, UnitQuaternion, TAffine, Transform, Translation, UnitDualQuaternion, UnitQuaternion,
}; };
/* /*
@ -21,6 +21,7 @@ use crate::geometry::{
* UnitQuaternion -> UnitQuaternion * UnitQuaternion -> UnitQuaternion
* UnitQuaternion -> Rotation<U3> * UnitQuaternion -> Rotation<U3>
* UnitQuaternion -> Isometry<U3> * UnitQuaternion -> Isometry<U3>
* UnitQuaternion -> UnitDualQuaternion
* UnitQuaternion -> Similarity<U3> * UnitQuaternion -> Similarity<U3>
* UnitQuaternion -> Transform<U3> * UnitQuaternion -> Transform<U3>
* UnitQuaternion -> Matrix<U4> (homogeneous) * UnitQuaternion -> Matrix<U4> (homogeneous)
@ -121,6 +122,28 @@ where
} }
} }
impl<N1, N2> SubsetOf<UnitDualQuaternion<N2>> for UnitQuaternion<N1>
where
N1: RealField,
N2: RealField + SupersetOf<N1>,
{
#[inline]
fn to_superset(&self) -> UnitDualQuaternion<N2> {
let q: UnitQuaternion<N2> = crate::convert_ref(self);
UnitDualQuaternion::from_rotation(q)
}
#[inline]
fn is_in_subset(dq: &UnitDualQuaternion<N2>) -> bool {
dq.translation().vector.is_zero()
}
#[inline]
fn from_superset_unchecked(dq: &UnitDualQuaternion<N2>) -> Self {
crate::convert_unchecked(dq.rotation())
}
}
impl<N1, N2, R> SubsetOf<Similarity<N2, U3, R>> for UnitQuaternion<N1> impl<N1, N2, R> SubsetOf<Similarity<N2, U3, R>> for UnitQuaternion<N1>
where where
N1: RealField, N1: RealField,

View File

@ -12,7 +12,7 @@ use crate::base::{DefaultAllocator, Matrix2, Matrix3, Matrix4, MatrixN, Scalar};
use crate::geometry::{ use crate::geometry::{
AbstractRotation, Isometry, Rotation, Rotation2, Rotation3, Similarity, SuperTCategoryOf, AbstractRotation, Isometry, Rotation, Rotation2, Rotation3, Similarity, SuperTCategoryOf,
TAffine, Transform, Translation, UnitComplex, UnitQuaternion, TAffine, Transform, Translation, UnitComplex, UnitDualQuaternion, UnitQuaternion,
}; };
/* /*
@ -21,6 +21,7 @@ use crate::geometry::{
* *
* Rotation -> Rotation * Rotation -> Rotation
* Rotation3 -> UnitQuaternion * Rotation3 -> UnitQuaternion
* Rotation3 -> UnitDualQuaternion
* Rotation2 -> UnitComplex * Rotation2 -> UnitComplex
* Rotation -> Isometry * Rotation -> Isometry
* Rotation -> Similarity * Rotation -> Similarity
@ -75,6 +76,31 @@ where
} }
} }
impl<N1, N2> SubsetOf<UnitDualQuaternion<N2>> for Rotation3<N1>
where
N1: RealField,
N2: RealField + SupersetOf<N1>,
{
#[inline]
fn to_superset(&self) -> UnitDualQuaternion<N2> {
let q = UnitQuaternion::<N1>::from_rotation_matrix(self);
let dq = UnitDualQuaternion::from_rotation(q);
dq.to_superset()
}
#[inline]
fn is_in_subset(dq: &UnitDualQuaternion<N2>) -> bool {
crate::is_convertible::<_, UnitQuaternion<N1>>(&dq.rotation())
&& dq.translation().vector.is_zero()
}
#[inline]
fn from_superset_unchecked(dq: &UnitDualQuaternion<N2>) -> Self {
let dq: UnitDualQuaternion<N1> = crate::convert_ref_unchecked(dq);
dq.rotation().to_rotation_matrix()
}
}
impl<N1, N2> SubsetOf<UnitComplex<N2>> for Rotation2<N1> impl<N1, N2> SubsetOf<UnitComplex<N2>> for Rotation2<N1>
where where
N1: RealField, N1: RealField,

View File

@ -12,7 +12,7 @@ use std::ops::Neg;
use crate::base::dimension::{U1, U2, U3}; use crate::base::dimension::{U1, U2, U3};
use crate::base::storage::Storage; use crate::base::storage::Storage;
use crate::base::{Matrix2, Matrix3, MatrixN, Unit, Vector, Vector1, Vector3, VectorN}; use crate::base::{Matrix2, Matrix3, MatrixN, Unit, Vector, Vector1, Vector2, Vector3, VectorN};
use crate::geometry::{Rotation2, Rotation3, UnitComplex, UnitQuaternion}; use crate::geometry::{Rotation2, Rotation3, UnitComplex, UnitQuaternion};
@ -53,6 +53,17 @@ impl<N: SimdRealField> Rotation2<N> {
/// # Construction from an existing 2D matrix or rotations /// # Construction from an existing 2D matrix or rotations
impl<N: SimdRealField> Rotation2<N> { impl<N: SimdRealField> Rotation2<N> {
/// Builds a rotation from a basis assumed to be orthonormal.
///
/// In order to get a valid unit-quaternion, the input must be an
/// orthonormal basis, i.e., all vectors are normalized, and the are
/// all orthogonal to each other. These invariants are not checked
/// by this method.
pub fn from_basis_unchecked(basis: &[Vector2<N>; 2]) -> Self {
let mat = Matrix2::from_columns(&basis[..]);
Self::from_matrix_unchecked(mat)
}
/// Builds a rotation matrix by extracting the rotation part of the given transformation `m`. /// Builds a rotation matrix by extracting the rotation part of the given transformation `m`.
/// ///
/// This is an iterative method. See `.from_matrix_eps` to provide mover /// This is an iterative method. See `.from_matrix_eps` to provide mover
@ -264,7 +275,7 @@ where
Owned<N, U2, U2>: Send, Owned<N, U2, U2>: Send,
{ {
#[inline] #[inline]
fn arbitrary<G: Gen>(g: &mut G) -> Self { fn arbitrary(g: &mut Gen) -> Self {
Self::new(N::arbitrary(g)) Self::new(N::arbitrary(g))
} }
} }
@ -655,6 +666,17 @@ where
} }
} }
/// Builds a rotation from a basis assumed to be orthonormal.
///
/// In order to get a valid unit-quaternion, the input must be an
/// orthonormal basis, i.e., all vectors are normalized, and the are
/// all orthogonal to each other. These invariants are not checked
/// by this method.
pub fn from_basis_unchecked(basis: &[Vector3<N>; 3]) -> Self {
let mat = Matrix3::from_columns(&basis[..]);
Self::from_matrix_unchecked(mat)
}
/// Builds a rotation matrix by extracting the rotation part of the given transformation `m`. /// Builds a rotation matrix by extracting the rotation part of the given transformation `m`.
/// ///
/// This is an iterative method. See `.from_matrix_eps` to provide mover /// This is an iterative method. See `.from_matrix_eps` to provide mover
@ -939,7 +961,7 @@ where
Owned<N, U3>: Send, Owned<N, U3>: Send,
{ {
#[inline] #[inline]
fn arbitrary<G: Gen>(g: &mut G) -> Self { fn arbitrary(g: &mut Gen) -> Self {
Self::new(VectorN::arbitrary(g)) Self::new(VectorN::arbitrary(g))
} }
} }

View File

@ -114,7 +114,7 @@ where
Owned<N, D>: Send, Owned<N, D>: Send,
{ {
#[inline] #[inline]
fn arbitrary<G: Gen>(rng: &mut G) -> Self { fn arbitrary(rng: &mut Gen) -> Self {
let mut s: N = Arbitrary::arbitrary(rng); let mut s: N = Arbitrary::arbitrary(rng);
while s.is_zero() { while s.is_zero() {
s = Arbitrary::arbitrary(rng) s = Arbitrary::arbitrary(rng)

View File

@ -61,13 +61,13 @@ where
} }
#[cfg(feature = "arbitrary")] #[cfg(feature = "arbitrary")]
impl<N: Scalar + Arbitrary, D: DimName> Arbitrary for Translation<N, D> impl<N: Scalar + Arbitrary + Send, D: DimName> Arbitrary for Translation<N, D>
where where
DefaultAllocator: Allocator<N, D>, DefaultAllocator: Allocator<N, D>,
Owned<N, D>: Send, Owned<N, D>: Send,
{ {
#[inline] #[inline]
fn arbitrary<G: Gen>(rng: &mut G) -> Self { fn arbitrary(rng: &mut Gen) -> Self {
let v: VectorN<N, D> = Arbitrary::arbitrary(rng); let v: VectorN<N, D> = Arbitrary::arbitrary(rng);
Self::from(v) Self::from(v)
} }

View File

@ -9,6 +9,7 @@ use crate::base::{DefaultAllocator, MatrixN, Scalar, VectorN};
use crate::geometry::{ use crate::geometry::{
AbstractRotation, Isometry, Similarity, SuperTCategoryOf, TAffine, Transform, Translation, AbstractRotation, Isometry, Similarity, SuperTCategoryOf, TAffine, Transform, Translation,
Translation3, UnitDualQuaternion, UnitQuaternion,
}; };
/* /*
@ -17,6 +18,7 @@ use crate::geometry::{
* *
* Translation -> Translation * Translation -> Translation
* Translation -> Isometry * Translation -> Isometry
* Translation3 -> UnitDualQuaternion
* Translation -> Similarity * Translation -> Similarity
* Translation -> Transform * Translation -> Transform
* Translation -> Matrix (homogeneous) * Translation -> Matrix (homogeneous)
@ -69,6 +71,30 @@ where
} }
} }
impl<N1, N2> SubsetOf<UnitDualQuaternion<N2>> for Translation3<N1>
where
N1: RealField,
N2: RealField + SupersetOf<N1>,
{
#[inline]
fn to_superset(&self) -> UnitDualQuaternion<N2> {
let dq = UnitDualQuaternion::<N1>::from_parts(self.clone(), UnitQuaternion::identity());
dq.to_superset()
}
#[inline]
fn is_in_subset(dq: &UnitDualQuaternion<N2>) -> bool {
crate::is_convertible::<_, Translation<N1, _>>(&dq.translation())
&& dq.rotation() == UnitQuaternion::identity()
}
#[inline]
fn from_superset_unchecked(dq: &UnitDualQuaternion<N2>) -> Self {
let dq: UnitDualQuaternion<N1> = crate::convert_ref_unchecked(dq);
dq.translation()
}
}
impl<N1, N2, D: DimName, R> SubsetOf<Similarity<N2, D, R>> for Translation<N1, D> impl<N1, N2, D: DimName, R> SubsetOf<Similarity<N2, D, R>> for Translation<N1, D>
where where
N1: RealField, N1: RealField,

Some files were not shown because too many files have changed in this diff Show More