commit
39ef8b43cf
|
@ -1,119 +0,0 @@
|
|||
version: 2.1
|
||||
|
||||
executors:
|
||||
rust-nightly-executor:
|
||||
docker:
|
||||
- image: rustlang/rust:nightly
|
||||
rust-executor:
|
||||
docker:
|
||||
- image: rust:latest
|
||||
|
||||
|
||||
jobs:
|
||||
check-fmt:
|
||||
executor: rust-executor
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: install rustfmt
|
||||
command: rustup component add rustfmt
|
||||
- run:
|
||||
name: check formatting
|
||||
command: cargo fmt -- --check
|
||||
clippy:
|
||||
executor: rust-executor
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: install clippy
|
||||
command: rustup component add clippy
|
||||
- run:
|
||||
name: clippy
|
||||
command: cargo clippy
|
||||
build-native:
|
||||
executor: rust-executor
|
||||
steps:
|
||||
- checkout
|
||||
- run: apt-get update
|
||||
- run: apt-get install -y cmake gfortran libblas-dev liblapack-dev
|
||||
- run:
|
||||
name: build --no-default-feature
|
||||
command: cargo build --no-default-features;
|
||||
- run:
|
||||
name: build (default features)
|
||||
command: cargo build;
|
||||
- run:
|
||||
name: build --all-features
|
||||
command: cargo build --all-features
|
||||
- run:
|
||||
name: build nalgebra-glm
|
||||
command: cargo build -p nalgebra-glm --all-features
|
||||
- run:
|
||||
name: build nalgebra-lapack
|
||||
command: cd nalgebra-lapack; cargo build
|
||||
test-native:
|
||||
executor: rust-executor
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: test
|
||||
command: cargo test --features arbitrary --features serde-serialize --features abomonation-serialize --features sparse --features debug --features io --features compare --features libm
|
||||
- run:
|
||||
name: test nalgebra-glm
|
||||
command: cargo test -p nalgebra-glm --features arbitrary --features serde-serialize --features abomonation-serialize --features sparse --features debug --features io --features compare --features libm
|
||||
build-wasm:
|
||||
executor: rust-executor
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: install cargo-web
|
||||
command: cargo install -f cargo-web;
|
||||
- run:
|
||||
name: build --all-features
|
||||
command: cargo web build --verbose --target wasm32-unknown-unknown;
|
||||
- run:
|
||||
name: build nalgebra-glm
|
||||
command: cargo build -p nalgebra-glm --all-features
|
||||
build-no-std:
|
||||
executor: rust-nightly-executor
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: install xargo
|
||||
command: cp .circleci/Xargo.toml .; rustup component add rust-src; cargo install -f xargo;
|
||||
- run:
|
||||
name: build
|
||||
command: xargo build --verbose --no-default-features --target=x86_64-unknown-linux-gnu;
|
||||
- run:
|
||||
name: build --features alloc
|
||||
command: xargo build --verbose --no-default-features --features alloc --target=x86_64-unknown-linux-gnu;
|
||||
build-nightly:
|
||||
executor: rust-nightly-executor
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: build --all-features
|
||||
command: cargo build --all-features
|
||||
|
||||
|
||||
workflows:
|
||||
version: 2
|
||||
build:
|
||||
jobs:
|
||||
- check-fmt
|
||||
- clippy
|
||||
- build-native:
|
||||
requires:
|
||||
- check-fmt
|
||||
- build-wasm:
|
||||
requires:
|
||||
- check-fmt
|
||||
- build-no-std:
|
||||
requires:
|
||||
- check-fmt
|
||||
- build-nightly:
|
||||
requires:
|
||||
- check-fmt
|
||||
- test-native:
|
||||
requires:
|
||||
- build-native
|
|
@ -0,0 +1,96 @@
|
|||
name: nalgebra CI build
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ dev, master ]
|
||||
pull_request:
|
||||
branches: [ dev, master ]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
check-fmt:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Check formatting
|
||||
run: cargo fmt -- --check
|
||||
clippy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Install clippy
|
||||
run: rustup component add clippy
|
||||
- name: Run clippy
|
||||
run: cargo clippy
|
||||
build-nalgebra:
|
||||
runs-on: ubuntu-latest
|
||||
# env:
|
||||
# RUSTFLAGS: -D warnings
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Build --no-default-feature
|
||||
run: cargo build --no-default-features;
|
||||
- name: Build (default features)
|
||||
run: cargo build;
|
||||
- name: Build --all-features
|
||||
run: cargo build --all-features;
|
||||
- name: Build nalgebra-glm
|
||||
run: cargo build -p nalgebra-glm --all-features;
|
||||
- name: Build nalgebra-lapack
|
||||
run: cd nalgebra-lapack; cargo build;
|
||||
- name: Build nalgebra-sparse
|
||||
run: cd nalgebra-sparse; cargo build;
|
||||
test-nalgebra:
|
||||
runs-on: ubuntu-latest
|
||||
# env:
|
||||
# RUSTFLAGS: -D warnings
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: test
|
||||
run: cargo test --features arbitrary --features serde-serialize,abomonation-serialize,sparse,debug,io,compare,libm,proptest-support,slow-tests;
|
||||
test-nalgebra-glm:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: test nalgebra-glm
|
||||
run: cargo test -p nalgebra-glm --features arbitrary,serde-serialize,abomonation-serialize,sparse,debug,io,compare,libm,proptest-support,slow-tests;
|
||||
test-nalgebra-sparse:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: test nalgebra-sparse
|
||||
# Manifest-path is necessary because cargo otherwise won't correctly forward features
|
||||
# We increase number of proptest cases to hopefully catch more potential bugs
|
||||
run: PROPTEST_CASES=10000 cargo test --manifest-path=nalgebra-sparse/Cargo.toml --features compare,proptest-support
|
||||
- name: test nalgebra-sparse (slow tests)
|
||||
# Unfortunately, the "slow-tests" take so much time that we need to run them with --release
|
||||
run: PROPTEST_CASES=10000 cargo test --release --manifest-path=nalgebra-sparse/Cargo.toml --features compare,proptest-support,slow-tests slow
|
||||
build-wasm:
|
||||
runs-on: ubuntu-latest
|
||||
# env:
|
||||
# RUSTFLAGS: -D warnings
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- run: rustup target add wasm32-unknown-unknown
|
||||
- name: build nalgebra
|
||||
run: cargo build --verbose --target wasm32-unknown-unknown;
|
||||
- name: build nalgebra-glm
|
||||
run: cargo build -p nalgebra-glm --verbose --target wasm32-unknown-unknown;
|
||||
build-no-std:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Install latest nightly
|
||||
uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly
|
||||
override: true
|
||||
components: rustfmt
|
||||
- name: install xargo
|
||||
run: cp .github/Xargo.toml .; rustup component add rust-src; cargo install -f xargo;
|
||||
- name: build
|
||||
run: xargo build --verbose --no-default-features --target=x86_64-unknown-linux-gnu;
|
||||
- name: build --feature alloc
|
||||
run: xargo build --verbose --no-default-features --features alloc --target=x86_64-unknown-linux-gnu;
|
|
@ -10,3 +10,4 @@ Cargo.lock
|
|||
site/
|
||||
.vscode/
|
||||
.idea/
|
||||
proptest-regressions
|
23
CHANGELOG.md
23
CHANGELOG.md
|
@ -4,6 +4,27 @@ documented here.
|
|||
|
||||
This project adheres to [Semantic Versioning](https://semver.org/).
|
||||
|
||||
## [0.25.0]
|
||||
This updates all the dependencies of nalgebra to their latest version, including:
|
||||
- rand 0.8
|
||||
- proptest 1.0
|
||||
- simba 0.4
|
||||
|
||||
### New crate!
|
||||
Alongside this release of `nalgebra`, we are releasing `nalgebra-sparse`: a crate dedicated to sparse matrix
|
||||
computation with `nalgebra`. The `sparse` module of `nalgebra`itself still exists for backward compatibility
|
||||
but it will be deprecated soon in favor of the `nalgebra-sparse` crate.
|
||||
|
||||
### Added
|
||||
* Add `UnitDualQuaternion`, a dual-quaternion with unit magnitude which can be used as an isometry transformation.
|
||||
* Add `UDU::new()` and `matrix.udu()` to compute the UDU factorization of a matrix.
|
||||
* Add `ColPivQR::new()` and `matrix.col_piv_qr()` to compute the QR decomposition with column pivoting of a matrix.
|
||||
* Add `from_basis_unchecked` to all the rotation types. This builds a rotation from a set of basis vectors (representing the columns of the corresponding rotation matrix).
|
||||
* Add `Matrix::cap_magnitude` to cap the magnitude of a vector.
|
||||
* Add `UnitQuaternion::append_axisangle_linearized` to approximately append a rotation represented as an axis-angle to a rotation represented as an unit quaternion.
|
||||
* Mark the iterators on matrix components as `DoubleEndedIter`.
|
||||
* Re-export `simba::simd::SimdValue` at the root of the `nalgebra` crate.
|
||||
|
||||
## [0.24.0]
|
||||
|
||||
### Added
|
||||
|
@ -67,7 +88,7 @@ In this release, we are no longer relying on traits from the __alga__ crate for
|
|||
Instead, we use traits from the new [simba](https://crates.io/crates/simba) crate which are both
|
||||
simpler, and allow for significant optimizations like AoSoA SIMD.
|
||||
|
||||
Refer to the [monthly Rustsim blogpost](https://www.rustsim.org/blog/2020/04/01/this-month-in-rustsim/)
|
||||
Refer to the [monthly dimforge blogpost](https://www.dimforge.org/blog/2020/04/01/this-month-in-dimforge/)
|
||||
for details about this switch and its benefits.
|
||||
|
||||
### Added
|
||||
|
|
49
Cargo.toml
49
Cargo.toml
|
@ -1,28 +1,29 @@
|
|||
[package]
|
||||
name = "nalgebra"
|
||||
version = "0.24.0"
|
||||
version = "0.25.0"
|
||||
authors = [ "Sébastien Crozet <developer@crozet.re>" ]
|
||||
|
||||
description = "Linear algebra library with transformations and statically-sized or dynamically-sized matrices."
|
||||
documentation = "https://nalgebra.org/rustdoc/nalgebra/index.html"
|
||||
description = "General-purpose linear algebra library with transformations and statically-sized or dynamically-sized matrices."
|
||||
documentation = "https://www.nalgebra.org/docs"
|
||||
homepage = "https://nalgebra.org"
|
||||
repository = "https://github.com/rustsim/nalgebra"
|
||||
repository = "https://github.com/dimforge/nalgebra"
|
||||
readme = "README.md"
|
||||
categories = [ "science" ]
|
||||
categories = [ "science", "mathematics", "wasm", "no-std" ]
|
||||
keywords = [ "linear", "algebra", "matrix", "vector", "math" ]
|
||||
license = "Apache-2.0"
|
||||
license = "BSD-3-Clause"
|
||||
edition = "2018"
|
||||
|
||||
exclude = ["/ci/*", "/.travis.yml", "/Makefile"]
|
||||
|
||||
[badges]
|
||||
maintenance = { status = "actively-developed" }
|
||||
|
||||
[lib]
|
||||
name = "nalgebra"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[features]
|
||||
default = [ "std" ]
|
||||
std = [ "matrixmultiply", "rand/std", "rand_distr", "simba/std" ]
|
||||
stdweb = [ "rand/stdweb" ]
|
||||
std = [ "matrixmultiply", "rand/std", "rand/std_rng", "rand_distr", "simba/std" ]
|
||||
arbitrary = [ "quickcheck" ]
|
||||
serde-serialize = [ "serde", "num-complex/serde" ]
|
||||
abomonation-serialize = [ "abomonation" ]
|
||||
|
@ -33,32 +34,39 @@ io = [ "pest", "pest_derive" ]
|
|||
compare = [ "matrixcompare-core" ]
|
||||
libm = [ "simba/libm" ]
|
||||
libm-force = [ "simba/libm_force" ]
|
||||
proptest-support = [ "proptest" ]
|
||||
no_unsound_assume_init = [ ]
|
||||
|
||||
# This feature is only used for tests, and enables tests that require more time to run
|
||||
slow-tests = []
|
||||
|
||||
[dependencies]
|
||||
typenum = "1.12"
|
||||
generic-array = "0.14"
|
||||
rand = { version = "0.7", default-features = false }
|
||||
rand = { version = "0.8", default-features = false }
|
||||
getrandom = { version = "0.2", default-features = false, features = [ "js" ] } # For wasm
|
||||
num-traits = { version = "0.2", default-features = false }
|
||||
num-complex = { version = "0.3", default-features = false }
|
||||
num-rational = { version = "0.3", default-features = false }
|
||||
approx = { version = "0.4", default-features = false }
|
||||
simba = { version = "0.3", default-features = false }
|
||||
simba = { version = "0.4", default-features = false }
|
||||
alga = { version = "0.9", default-features = false, optional = true }
|
||||
rand_distr = { version = "0.3", optional = true }
|
||||
matrixmultiply = { version = "0.2", optional = true }
|
||||
rand_distr = { version = "0.4", default-features = false, optional = true }
|
||||
matrixmultiply = { version = "0.3", optional = true }
|
||||
serde = { version = "1.0", default-features = false, features = [ "derive" ], optional = true }
|
||||
abomonation = { version = "0.7", optional = true }
|
||||
mint = { version = "0.5", optional = true }
|
||||
quickcheck = { version = "0.9", optional = true }
|
||||
quickcheck = { version = "1", optional = true }
|
||||
pest = { version = "2", optional = true }
|
||||
pest_derive = { version = "2", optional = true }
|
||||
bytemuck = { version = "1.5", optional = true }
|
||||
matrixcompare-core = { version = "0.1", optional = true }
|
||||
proptest = { version = "1", optional = true, default-features = false, features = ["std"] }
|
||||
|
||||
[dev-dependencies]
|
||||
serde_json = "1.0"
|
||||
rand_xorshift = "0.2"
|
||||
rand_isaac = "0.2"
|
||||
rand_xorshift = "0.3"
|
||||
rand_isaac = "0.3"
|
||||
### Uncomment this line before running benchmarks.
|
||||
### We can't just let this uncommented because that would break
|
||||
### compilation for #[no-std] because of the terrible Cargo bug
|
||||
|
@ -66,10 +74,11 @@ rand_isaac = "0.2"
|
|||
#criterion = "0.2.10"
|
||||
|
||||
# For matrix comparison macro
|
||||
matrixcompare = "0.1.3"
|
||||
matrixcompare = "0.2.0"
|
||||
itertools = "0.10"
|
||||
|
||||
[workspace]
|
||||
members = [ "nalgebra-lapack", "nalgebra-glm" ]
|
||||
members = [ "nalgebra-lapack", "nalgebra-glm", "nalgebra-sparse" ]
|
||||
|
||||
[[bench]]
|
||||
name = "nalgebra_bench"
|
||||
|
@ -78,3 +87,7 @@ path = "benches/lib.rs"
|
|||
|
||||
[profile.bench]
|
||||
lto = true
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
# Enable certain features when building docs for docs.rs
|
||||
features = [ "proptest-support", "compare" ]
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
</p>
|
||||
<p align = "center">
|
||||
<strong>
|
||||
<a href="https://nalgebra.org">Users guide</a> | <a href="https://nalgebra.org/rustdoc/nalgebra/index.html">Documentation</a> | <a href="https://discourse.nphysics.org/c/nalgebra">Forum</a>
|
||||
<a href="https://nalgebra.org">Users guide</a> | <a href="https://docs.rs/nalgebra/latest/nalgebra/">Documentation</a> | <a href="https://discourse.nphysics.org/c/nalgebra">Forum</a>
|
||||
</strong>
|
||||
</p>
|
||||
|
||||
|
|
|
@ -136,6 +136,30 @@ fn mat500_mul_mat500(bench: &mut criterion::Criterion) {
|
|||
bench.bench_function("mat500_mul_mat500", move |bh| bh.iter(|| &a * &b));
|
||||
}
|
||||
|
||||
fn iter(bench: &mut criterion::Criterion) {
|
||||
let a = DMatrix::<f64>::new_random(1000, 1000);
|
||||
|
||||
bench.bench_function("iter", move |bh| {
|
||||
bh.iter(|| {
|
||||
for value in a.iter() {
|
||||
criterion::black_box(value);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn iter_rev(bench: &mut criterion::Criterion) {
|
||||
let a = DMatrix::<f64>::new_random(1000, 1000);
|
||||
|
||||
bench.bench_function("iter_rev", move |bh| {
|
||||
bh.iter(|| {
|
||||
for value in a.iter().rev() {
|
||||
criterion::black_box(value);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
fn copy_from(bench: &mut criterion::Criterion) {
|
||||
let a = DMatrix::<f64>::new_random(1000, 1000);
|
||||
let mut b = DMatrix::<f64>::new_random(1000, 1000);
|
||||
|
@ -235,6 +259,8 @@ criterion_group!(
|
|||
mat10_mul_mat10_static,
|
||||
mat100_mul_mat100,
|
||||
mat500_mul_mat500,
|
||||
iter,
|
||||
iter_rev,
|
||||
copy_from,
|
||||
axpy,
|
||||
tr_mul_to,
|
||||
|
|
|
@ -1,22 +1,24 @@
|
|||
[package]
|
||||
name = "nalgebra-glm"
|
||||
version = "0.10.0"
|
||||
version = "0.11.0"
|
||||
authors = ["sebcrozet <developer@crozet.re>"]
|
||||
|
||||
description = "A computer-graphics oriented API for nalgebra, inspired by the C++ GLM library."
|
||||
documentation = "https://www.nalgebra.org/rustdoc_glm/nalgebra_glm/index.html"
|
||||
documentation = "https://www.nalgebra.org/docs"
|
||||
homepage = "https://nalgebra.org"
|
||||
repository = "https://github.com/rustsim/nalgebra"
|
||||
repository = "https://github.com/dimforge/nalgebra"
|
||||
readme = "../README.md"
|
||||
categories = [ "science" ]
|
||||
categories = [ "science", "mathematics", "wasm", "no standard library" ]
|
||||
keywords = [ "linear", "algebra", "matrix", "vector", "math" ]
|
||||
license = "BSD-3-Clause"
|
||||
edition = "2018"
|
||||
|
||||
[badges]
|
||||
maintenance = { status = "actively-developed" }
|
||||
|
||||
[features]
|
||||
default = [ "std" ]
|
||||
std = [ "nalgebra/std", "simba/std" ]
|
||||
stdweb = [ "nalgebra/stdweb" ]
|
||||
arbitrary = [ "nalgebra/arbitrary" ]
|
||||
serde-serialize = [ "nalgebra/serde-serialize" ]
|
||||
abomonation-serialize = [ "nalgebra/abomonation-serialize" ]
|
||||
|
@ -24,5 +26,5 @@ abomonation-serialize = [ "nalgebra/abomonation-serialize" ]
|
|||
[dependencies]
|
||||
num-traits = { version = "0.2", default-features = false }
|
||||
approx = { version = "0.4", default-features = false }
|
||||
simba = { version = "0.3", default-features = false }
|
||||
nalgebra = { path = "..", version = "0.24", default-features = false }
|
||||
simba = { version = "0.4", default-features = false }
|
||||
nalgebra = { path = "..", version = "0.25", default-features = false }
|
||||
|
|
|
@ -1,40 +1,47 @@
|
|||
[package]
|
||||
name = "nalgebra-lapack"
|
||||
version = "0.15.0"
|
||||
version = "0.16.0"
|
||||
authors = [ "Sébastien Crozet <developer@crozet.re>", "Andrew Straw <strawman@astraw.com>" ]
|
||||
|
||||
description = "Linear algebra library with transformations and satically-sized or dynamically-sized matrices."
|
||||
documentation = "https://nalgebra.org/doc/nalgebra/index.html"
|
||||
description = "Matrix decompositions using nalgebra matrices and Lapack bindings."
|
||||
documentation = "https://www.nalgebra.org/docs"
|
||||
homepage = "https://nalgebra.org"
|
||||
repository = "https://github.com/rustsim/nalgebra"
|
||||
readme = "README.md"
|
||||
keywords = [ "linear", "algebra", "matrix", "vector" ]
|
||||
repository = "https://github.com/dimforge/nalgebra"
|
||||
readme = "../README.md"
|
||||
categories = [ "science", "mathematics" ]
|
||||
keywords = [ "linear", "algebra", "matrix", "vector", "math", "lapack" ]
|
||||
license = "BSD-3-Clause"
|
||||
edition = "2018"
|
||||
|
||||
[badges]
|
||||
maintenance = { status = "actively-developed" }
|
||||
|
||||
[features]
|
||||
serde-serialize = [ "serde", "serde_derive" ]
|
||||
proptest-support = [ "nalgebra/proptest-support" ]
|
||||
arbitrary = [ "nalgebra/arbitrary" ]
|
||||
|
||||
# For BLAS/LAPACK
|
||||
default = ["openblas"]
|
||||
default = ["netlib"]
|
||||
openblas = ["lapack-src/openblas"]
|
||||
netlib = ["lapack-src/netlib"]
|
||||
accelerate = ["lapack-src/accelerate"]
|
||||
intel-mkl = ["lapack-src/intel-mkl"]
|
||||
|
||||
[dependencies]
|
||||
nalgebra = { version = "0.24", path = ".." }
|
||||
nalgebra = { version = "0.25", path = ".." }
|
||||
num-traits = "0.2"
|
||||
num-complex = { version = "0.2", default-features = false }
|
||||
simba = "0.2"
|
||||
num-complex = { version = "0.3", default-features = false }
|
||||
simba = "0.4"
|
||||
serde = { version = "1.0", optional = true }
|
||||
serde_derive = { version = "1.0", optional = true }
|
||||
lapack = { version = "0.16", default-features = false }
|
||||
lapack-src = { version = "0.5", default-features = false }
|
||||
lapack = { version = "0.17", default-features = false }
|
||||
lapack-src = { version = "0.6", default-features = false }
|
||||
# clippy = "*"
|
||||
|
||||
[dev-dependencies]
|
||||
nalgebra = { version = "0.24", features = [ "arbitrary" ], path = ".." }
|
||||
quickcheck = "0.9"
|
||||
approx = "0.3"
|
||||
rand = "0.7"
|
||||
nalgebra = { version = "0.25", features = [ "arbitrary" ], path = ".." }
|
||||
proptest = { version = "1", default-features = false, features = ["std"] }
|
||||
quickcheck = "1"
|
||||
approx = "0.4"
|
||||
rand = "0.8"
|
||||
|
|
|
@ -78,9 +78,9 @@ where
|
|||
|
||||
let lda = n as i32;
|
||||
|
||||
let mut wr = unsafe { Matrix::new_uninitialized_generic(nrows, U1) };
|
||||
let mut wr = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
|
||||
// TODO: Tap into the workspace.
|
||||
let mut wi = unsafe { Matrix::new_uninitialized_generic(nrows, U1) };
|
||||
let mut wi = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
|
||||
|
||||
let mut info = 0;
|
||||
let mut placeholder1 = [N::zero()];
|
||||
|
@ -107,8 +107,10 @@ where
|
|||
|
||||
match (left_eigenvectors, eigenvectors) {
|
||||
(true, true) => {
|
||||
let mut vl = unsafe { Matrix::new_uninitialized_generic(nrows, ncols) };
|
||||
let mut vr = unsafe { Matrix::new_uninitialized_generic(nrows, ncols) };
|
||||
let mut vl =
|
||||
unsafe { Matrix::new_uninitialized_generic(nrows, ncols).assume_init() };
|
||||
let mut vr =
|
||||
unsafe { Matrix::new_uninitialized_generic(nrows, ncols).assume_init() };
|
||||
|
||||
N::xgeev(
|
||||
ljob,
|
||||
|
@ -137,7 +139,8 @@ where
|
|||
}
|
||||
}
|
||||
(true, false) => {
|
||||
let mut vl = unsafe { Matrix::new_uninitialized_generic(nrows, ncols) };
|
||||
let mut vl =
|
||||
unsafe { Matrix::new_uninitialized_generic(nrows, ncols).assume_init() };
|
||||
|
||||
N::xgeev(
|
||||
ljob,
|
||||
|
@ -166,7 +169,8 @@ where
|
|||
}
|
||||
}
|
||||
(false, true) => {
|
||||
let mut vr = unsafe { Matrix::new_uninitialized_generic(nrows, ncols) };
|
||||
let mut vr =
|
||||
unsafe { Matrix::new_uninitialized_generic(nrows, ncols).assume_init() };
|
||||
|
||||
N::xgeev(
|
||||
ljob,
|
||||
|
@ -243,8 +247,8 @@ where
|
|||
|
||||
let lda = n as i32;
|
||||
|
||||
let mut wr = unsafe { Matrix::new_uninitialized_generic(nrows, U1) };
|
||||
let mut wi = unsafe { Matrix::new_uninitialized_generic(nrows, U1) };
|
||||
let mut wr = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
|
||||
let mut wi = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
|
||||
|
||||
let mut info = 0;
|
||||
let mut placeholder1 = [N::zero()];
|
||||
|
@ -287,7 +291,7 @@ where
|
|||
);
|
||||
lapack_panic!(info);
|
||||
|
||||
let mut res = unsafe { Matrix::new_uninitialized_generic(nrows, U1) };
|
||||
let mut res = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
|
||||
|
||||
for i in 0..res.len() {
|
||||
res[i] = Complex::new(wr[i], wi[i]);
|
||||
|
|
|
@ -60,7 +60,7 @@ where
|
|||
"Unable to compute the hessenberg decomposition of an empty matrix."
|
||||
);
|
||||
|
||||
let mut tau = unsafe { Matrix::new_uninitialized_generic(nrows.sub(U1), U1) };
|
||||
let mut tau = unsafe { Matrix::new_uninitialized_generic(nrows.sub(U1), U1).assume_init() };
|
||||
|
||||
let mut info = 0;
|
||||
let lwork =
|
||||
|
|
|
@ -57,7 +57,8 @@ where
|
|||
let (nrows, ncols) = m.data.shape();
|
||||
|
||||
let mut info = 0;
|
||||
let mut tau = unsafe { Matrix::new_uninitialized_generic(nrows.min(ncols), U1) };
|
||||
let mut tau =
|
||||
unsafe { Matrix::new_uninitialized_generic(nrows.min(ncols), U1).assume_init() };
|
||||
|
||||
if nrows.value() == 0 || ncols.value() == 0 {
|
||||
return Self { qr: m, tau: tau };
|
||||
|
|
|
@ -78,9 +78,9 @@ where
|
|||
|
||||
let mut info = 0;
|
||||
|
||||
let mut wr = unsafe { Matrix::new_uninitialized_generic(nrows, U1) };
|
||||
let mut wi = unsafe { Matrix::new_uninitialized_generic(nrows, U1) };
|
||||
let mut q = unsafe { Matrix::new_uninitialized_generic(nrows, ncols) };
|
||||
let mut wr = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
|
||||
let mut wi = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
|
||||
let mut q = unsafe { Matrix::new_uninitialized_generic(nrows, ncols).assume_init() };
|
||||
// Placeholders:
|
||||
let mut bwork = [0i32];
|
||||
let mut unused = 0;
|
||||
|
@ -151,7 +151,8 @@ where
|
|||
where
|
||||
DefaultAllocator: Allocator<Complex<N>, D>,
|
||||
{
|
||||
let mut out = unsafe { VectorN::new_uninitialized_generic(self.t.data.shape().0, U1) };
|
||||
let mut out =
|
||||
unsafe { VectorN::new_uninitialized_generic(self.t.data.shape().0, U1).assume_init() };
|
||||
|
||||
for i in 0..out.len() {
|
||||
out[i] = Complex::new(self.re[i], self.im[i])
|
||||
|
|
|
@ -99,9 +99,9 @@ macro_rules! svd_impl(
|
|||
|
||||
let lda = nrows.value() as i32;
|
||||
|
||||
let mut u = unsafe { Matrix::new_uninitialized_generic(nrows, nrows) };
|
||||
let mut s = unsafe { Matrix::new_uninitialized_generic(nrows.min(ncols), U1) };
|
||||
let mut vt = unsafe { Matrix::new_uninitialized_generic(ncols, ncols) };
|
||||
let mut u = unsafe { Matrix::new_uninitialized_generic(nrows, nrows).assume_init() };
|
||||
let mut s = unsafe { Matrix::new_uninitialized_generic(nrows.min(ncols), U1).assume_init() };
|
||||
let mut vt = unsafe { Matrix::new_uninitialized_generic(ncols, ncols).assume_init() };
|
||||
|
||||
let ldu = nrows.value();
|
||||
let ldvt = ncols.value();
|
||||
|
|
|
@ -94,7 +94,7 @@ where
|
|||
|
||||
let lda = n as i32;
|
||||
|
||||
let mut values = unsafe { Matrix::new_uninitialized_generic(nrows, U1) };
|
||||
let mut values = unsafe { Matrix::new_uninitialized_generic(nrows, U1).assume_init() };
|
||||
let mut info = 0;
|
||||
|
||||
let lwork = N::xsyev_work_size(jobz, b'L', n as i32, m.as_mut_slice(), lda, &mut info);
|
||||
|
|
|
@ -1,8 +1,14 @@
|
|||
#[macro_use]
|
||||
extern crate approx;
|
||||
#[cfg(not(feature = "proptest-support"))]
|
||||
compile_error!("Tests must be run with `proptest-support`");
|
||||
|
||||
extern crate nalgebra as na;
|
||||
extern crate nalgebra_lapack as nl;
|
||||
#[macro_use]
|
||||
extern crate quickcheck;
|
||||
|
||||
extern crate lapack;
|
||||
extern crate lapack_src;
|
||||
|
||||
mod linalg;
|
||||
#[path = "../../tests/proptest/mod.rs"]
|
||||
mod proptest;
|
||||
|
|
|
@ -1,37 +1,36 @@
|
|||
use std::cmp;
|
||||
|
||||
use na::{DMatrix, DVector, Matrix3, Matrix4, Matrix4x3, Vector4};
|
||||
use na::{DMatrix, DVector, Matrix4x3, Vector4};
|
||||
use nl::Cholesky;
|
||||
|
||||
quickcheck! {
|
||||
fn cholesky(m: DMatrix<f64>) -> bool {
|
||||
if m.len() != 0 {
|
||||
use crate::proptest::*;
|
||||
use proptest::{prop_assert, proptest};
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn cholesky(m in dmatrix()) {
|
||||
let m = &m * m.transpose();
|
||||
if let Some(chol) = Cholesky::new(m.clone()) {
|
||||
let l = chol.unpack();
|
||||
let reconstructed_m = &l * l.transpose();
|
||||
|
||||
return relative_eq!(reconstructed_m, m, epsilon = 1.0e-7)
|
||||
prop_assert!(relative_eq!(reconstructed_m, m, epsilon = 1.0e-7));
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
fn cholesky_static(m: Matrix3<f64>) -> bool {
|
||||
#[test]
|
||||
fn cholesky_static(m in matrix3()) {
|
||||
let m = &m * m.transpose();
|
||||
if let Some(chol) = Cholesky::new(m) {
|
||||
let l = chol.unpack();
|
||||
let reconstructed_m = &l * l.transpose();
|
||||
|
||||
relative_eq!(reconstructed_m, m, epsilon = 1.0e-7)
|
||||
}
|
||||
else {
|
||||
false
|
||||
prop_assert!(relative_eq!(reconstructed_m, m, epsilon = 1.0e-7))
|
||||
}
|
||||
}
|
||||
|
||||
fn cholesky_solve(n: usize, nb: usize) -> bool {
|
||||
if n != 0 {
|
||||
#[test]
|
||||
fn cholesky_solve(n in PROPTEST_MATRIX_DIM, nb in PROPTEST_MATRIX_DIM) {
|
||||
let n = cmp::min(n, 15); // To avoid slowing down the test too much.
|
||||
let nb = cmp::min(nb, 15); // To avoid slowing down the test too much.
|
||||
let m = DMatrix::<f64>::new_random(n, n);
|
||||
|
@ -44,33 +43,28 @@ quickcheck! {
|
|||
let sol1 = chol.solve(&b1).unwrap();
|
||||
let sol2 = chol.solve(&b2).unwrap();
|
||||
|
||||
return relative_eq!(&m * sol1, b1, epsilon = 1.0e-6) &&
|
||||
relative_eq!(&m * sol2, b2, epsilon = 1.0e-6)
|
||||
prop_assert!(relative_eq!(&m * sol1, b1, epsilon = 1.0e-6));
|
||||
prop_assert!(relative_eq!(&m * sol2, b2, epsilon = 1.0e-6));
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
fn cholesky_solve_static(m: Matrix4<f64>) -> bool {
|
||||
#[test]
|
||||
fn cholesky_solve_static(m in matrix4()) {
|
||||
let m = &m * m.transpose();
|
||||
match Cholesky::new(m) {
|
||||
Some(chol) => {
|
||||
if let Some(chol) = Cholesky::new(m) {
|
||||
let b1 = Vector4::new_random();
|
||||
let b2 = Matrix4x3::new_random();
|
||||
|
||||
let sol1 = chol.solve(&b1).unwrap();
|
||||
let sol2 = chol.solve(&b2).unwrap();
|
||||
|
||||
relative_eq!(m * sol1, b1, epsilon = 1.0e-7) &&
|
||||
relative_eq!(m * sol2, b2, epsilon = 1.0e-7)
|
||||
},
|
||||
None => true
|
||||
prop_assert!(relative_eq!(m * sol1, b1, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(m * sol2, b2, epsilon = 1.0e-7));
|
||||
}
|
||||
}
|
||||
|
||||
fn cholesky_inverse(n: usize) -> bool {
|
||||
if n != 0 {
|
||||
#[test]
|
||||
fn cholesky_inverse(n in PROPTEST_MATRIX_DIM) {
|
||||
let n = cmp::min(n, 15); // To avoid slowing down the test too much.
|
||||
let m = DMatrix::<f64>::new_random(n, n);
|
||||
let m = &m * m.transpose();
|
||||
|
@ -79,23 +73,18 @@ quickcheck! {
|
|||
let id1 = &m * &m1;
|
||||
let id2 = &m1 * &m;
|
||||
|
||||
return id1.is_identity(1.0e-6) && id2.is_identity(1.0e-6);
|
||||
prop_assert!(id1.is_identity(1.0e-6) && id2.is_identity(1.0e-6));
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
fn cholesky_inverse_static(m: Matrix4<f64>) -> bool {
|
||||
#[test]
|
||||
fn cholesky_inverse_static(m in matrix4()) {
|
||||
let m = m * m.transpose();
|
||||
match Cholesky::new(m.clone()).unwrap().inverse() {
|
||||
Some(m1) => {
|
||||
if let Some(m1) = Cholesky::new(m.clone()).unwrap().inverse() {
|
||||
let id1 = &m * &m1;
|
||||
let id2 = &m1 * &m;
|
||||
|
||||
id1.is_identity(1.0e-5) && id2.is_identity(1.0e-5)
|
||||
},
|
||||
None => true
|
||||
prop_assert!(id1.is_identity(1.0e-5) && id2.is_identity(1.0e-5))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,38 +1,32 @@
|
|||
use std::cmp;
|
||||
|
||||
use nl::Hessenberg;
|
||||
use na::{DMatrix, Matrix4};
|
||||
use nl::Hessenberg;
|
||||
|
||||
quickcheck!{
|
||||
fn hessenberg(n: usize) -> bool {
|
||||
if n != 0 {
|
||||
use crate::proptest::*;
|
||||
use proptest::{prop_assert, proptest};
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn hessenberg(n in PROPTEST_MATRIX_DIM) {
|
||||
let n = cmp::min(n, 25);
|
||||
let m = DMatrix::<f64>::new_random(n, n);
|
||||
|
||||
match Hessenberg::new(m.clone()) {
|
||||
Some(hess) => {
|
||||
if let Some(hess) = Hessenberg::new(m.clone()) {
|
||||
let h = hess.h();
|
||||
let p = hess.p();
|
||||
|
||||
relative_eq!(m, &p * h * p.transpose(), epsilon = 1.0e-7)
|
||||
},
|
||||
None => true
|
||||
}
|
||||
}
|
||||
else {
|
||||
true
|
||||
prop_assert!(relative_eq!(m, &p * h * p.transpose(), epsilon = 1.0e-7))
|
||||
}
|
||||
}
|
||||
|
||||
fn hessenberg_static(m: Matrix4<f64>) -> bool {
|
||||
match Hessenberg::new(m) {
|
||||
Some(hess) => {
|
||||
#[test]
|
||||
fn hessenberg_static(m in matrix4()) {
|
||||
if let Some(hess) = Hessenberg::new(m) {
|
||||
let h = hess.h();
|
||||
let p = hess.p();
|
||||
|
||||
relative_eq!(m, p * h * p.transpose(), epsilon = 1.0e-7)
|
||||
},
|
||||
None => true
|
||||
prop_assert!(relative_eq!(m, p * h * p.transpose(), epsilon = 1.0e-7))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
use std::cmp;
|
||||
|
||||
use na::{DMatrix, DVector, Matrix3x4, Matrix4, Matrix4x3, Vector4};
|
||||
use na::{DMatrix, DVector, Matrix4x3, Vector4};
|
||||
use nl::LU;
|
||||
|
||||
quickcheck! {
|
||||
fn lup(m: DMatrix<f64>) -> bool {
|
||||
if m.len() != 0 {
|
||||
use crate::proptest::*;
|
||||
use proptest::{prop_assert, proptest};
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn lup(m in dmatrix()) {
|
||||
let lup = LU::new(m.clone());
|
||||
let l = lup.l();
|
||||
let u = lup.u();
|
||||
|
@ -14,15 +17,12 @@ quickcheck! {
|
|||
|
||||
let computed2 = lup.p() * l * u;
|
||||
|
||||
relative_eq!(computed1, m, epsilon = 1.0e-7) &&
|
||||
relative_eq!(computed2, m, epsilon = 1.0e-7)
|
||||
}
|
||||
else {
|
||||
true
|
||||
}
|
||||
prop_assert!(relative_eq!(computed1, m, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(computed2, m, epsilon = 1.0e-7));
|
||||
}
|
||||
|
||||
fn lu_static(m: Matrix3x4<f64>) -> bool {
|
||||
#[test]
|
||||
fn lu_static(m in matrix3x5()) {
|
||||
let lup = LU::new(m);
|
||||
let l = lup.l();
|
||||
let u = lup.u();
|
||||
|
@ -31,12 +31,12 @@ quickcheck! {
|
|||
|
||||
let computed2 = lup.p() * l * u;
|
||||
|
||||
relative_eq!(computed1, m, epsilon = 1.0e-7) &&
|
||||
relative_eq!(computed2, m, epsilon = 1.0e-7)
|
||||
prop_assert!(relative_eq!(computed1, m, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(computed2, m, epsilon = 1.0e-7));
|
||||
}
|
||||
|
||||
fn lu_solve(n: usize, nb: usize) -> bool {
|
||||
if n != 0 {
|
||||
#[test]
|
||||
fn lu_solve(n in PROPTEST_MATRIX_DIM, nb in PROPTEST_MATRIX_DIM) {
|
||||
let n = cmp::min(n, 25); // To avoid slowing down the test too much.
|
||||
let nb = cmp::min(nb, 25); // To avoid slowing down the test too much.
|
||||
let m = DMatrix::<f64>::new_random(n, n);
|
||||
|
@ -51,17 +51,14 @@ quickcheck! {
|
|||
let tr_sol1 = lup.solve_transpose(&b1).unwrap();
|
||||
let tr_sol2 = lup.solve_transpose(&b2).unwrap();
|
||||
|
||||
relative_eq!(&m * sol1, b1, epsilon = 1.0e-7) &&
|
||||
relative_eq!(&m * sol2, b2, epsilon = 1.0e-7) &&
|
||||
relative_eq!(m.transpose() * tr_sol1, b1, epsilon = 1.0e-7) &&
|
||||
relative_eq!(m.transpose() * tr_sol2, b2, epsilon = 1.0e-7)
|
||||
}
|
||||
else {
|
||||
true
|
||||
}
|
||||
prop_assert!(relative_eq!(&m * sol1, b1, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(&m * sol2, b2, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(m.transpose() * tr_sol1, b1, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(m.transpose() * tr_sol2, b2, epsilon = 1.0e-7));
|
||||
}
|
||||
|
||||
fn lu_solve_static(m: Matrix4<f64>) -> bool {
|
||||
#[test]
|
||||
fn lu_solve_static(m in matrix4()) {
|
||||
let lup = LU::new(m);
|
||||
let b1 = Vector4::new_random();
|
||||
let b2 = Matrix4x3::new_random();
|
||||
|
@ -71,14 +68,14 @@ quickcheck! {
|
|||
let tr_sol1 = lup.solve_transpose(&b1).unwrap();
|
||||
let tr_sol2 = lup.solve_transpose(&b2).unwrap();
|
||||
|
||||
relative_eq!(m * sol1, b1, epsilon = 1.0e-7) &&
|
||||
relative_eq!(m * sol2, b2, epsilon = 1.0e-7) &&
|
||||
relative_eq!(m.transpose() * tr_sol1, b1, epsilon = 1.0e-7) &&
|
||||
relative_eq!(m.transpose() * tr_sol2, b2, epsilon = 1.0e-7)
|
||||
prop_assert!(relative_eq!(m * sol1, b1, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(m * sol2, b2, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(m.transpose() * tr_sol1, b1, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(m.transpose() * tr_sol2, b2, epsilon = 1.0e-7));
|
||||
}
|
||||
|
||||
fn lu_inverse(n: usize) -> bool {
|
||||
if n != 0 {
|
||||
#[test]
|
||||
fn lu_inverse(n in PROPTEST_MATRIX_DIM) {
|
||||
let n = cmp::min(n, 25); // To avoid slowing down the test too much.
|
||||
let m = DMatrix::<f64>::new_random(n, n);
|
||||
|
||||
|
@ -86,22 +83,17 @@ quickcheck! {
|
|||
let id1 = &m * &m1;
|
||||
let id2 = &m1 * &m;
|
||||
|
||||
return id1.is_identity(1.0e-7) && id2.is_identity(1.0e-7);
|
||||
prop_assert!(id1.is_identity(1.0e-7) && id2.is_identity(1.0e-7));
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
fn lu_inverse_static(m: Matrix4<f64>) -> bool {
|
||||
match LU::new(m.clone()).inverse() {
|
||||
Some(m1) => {
|
||||
#[test]
|
||||
fn lu_inverse_static(m in matrix4()) {
|
||||
if let Some(m1) = LU::new(m.clone()).inverse() {
|
||||
let id1 = &m * &m1;
|
||||
let id2 = &m1 * &m;
|
||||
|
||||
id1.is_identity(1.0e-5) && id2.is_identity(1.0e-5)
|
||||
},
|
||||
None => true
|
||||
prop_assert!(id1.is_identity(1.0e-5) && id2.is_identity(1.0e-5))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,20 +1,24 @@
|
|||
use na::{DMatrix, Matrix4x3};
|
||||
use nl::QR;
|
||||
|
||||
quickcheck! {
|
||||
fn qr(m: DMatrix<f64>) -> bool {
|
||||
use crate::proptest::*;
|
||||
use proptest::{prop_assert, proptest};
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn qr(m in dmatrix()) {
|
||||
let qr = QR::new(m.clone());
|
||||
let q = qr.q();
|
||||
let r = qr.r();
|
||||
|
||||
relative_eq!(m, q * r, epsilon = 1.0e-7)
|
||||
prop_assert!(relative_eq!(m, q * r, epsilon = 1.0e-7))
|
||||
}
|
||||
|
||||
fn qr_static(m: Matrix4x3<f64>) -> bool {
|
||||
#[test]
|
||||
fn qr_static(m in matrix5x3()) {
|
||||
let qr = QR::new(m);
|
||||
let q = qr.q();
|
||||
let r = qr.r();
|
||||
|
||||
relative_eq!(m, q * r, epsilon = 1.0e-7)
|
||||
prop_assert!(relative_eq!(m, q * r, epsilon = 1.0e-7))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,14 +3,16 @@ use std::cmp;
|
|||
use na::{DMatrix, Matrix4};
|
||||
use nl::Eigen;
|
||||
|
||||
quickcheck! {
|
||||
fn eigensystem(n: usize) -> bool {
|
||||
if n != 0 {
|
||||
use crate::proptest::*;
|
||||
use proptest::{prop_assert, proptest};
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn eigensystem(n in PROPTEST_MATRIX_DIM) {
|
||||
let n = cmp::min(n, 25);
|
||||
let m = DMatrix::<f64>::new_random(n, n);
|
||||
|
||||
match Eigen::new(m.clone(), true, true) {
|
||||
Some(eig) => {
|
||||
if let Some(eig) = Eigen::new(m.clone(), true, true) {
|
||||
let eigvals = DMatrix::from_diagonal(&eig.eigenvalues);
|
||||
let transformed_eigvectors = &m * eig.eigenvectors.as_ref().unwrap();
|
||||
let scaled_eigvectors = eig.eigenvectors.as_ref().unwrap() * &eigvals;
|
||||
|
@ -18,20 +20,14 @@ quickcheck! {
|
|||
let transformed_left_eigvectors = m.transpose() * eig.left_eigenvectors.as_ref().unwrap();
|
||||
let scaled_left_eigvectors = eig.left_eigenvectors.as_ref().unwrap() * &eigvals;
|
||||
|
||||
relative_eq!(transformed_eigvectors, scaled_eigvectors, epsilon = 1.0e-7) &&
|
||||
relative_eq!(transformed_left_eigvectors, scaled_left_eigvectors, epsilon = 1.0e-7)
|
||||
},
|
||||
None => true
|
||||
}
|
||||
}
|
||||
else {
|
||||
true
|
||||
prop_assert!(relative_eq!(transformed_eigvectors, scaled_eigvectors, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(transformed_left_eigvectors, scaled_left_eigvectors, epsilon = 1.0e-7));
|
||||
}
|
||||
}
|
||||
|
||||
fn eigensystem_static(m: Matrix4<f64>) -> bool {
|
||||
match Eigen::new(m, true, true) {
|
||||
Some(eig) => {
|
||||
#[test]
|
||||
fn eigensystem_static(m in matrix4()) {
|
||||
if let Some(eig) = Eigen::new(m, true, true) {
|
||||
let eigvals = Matrix4::from_diagonal(&eig.eigenvalues);
|
||||
let transformed_eigvectors = m * eig.eigenvectors.unwrap();
|
||||
let scaled_eigvectors = eig.eigenvectors.unwrap() * eigvals;
|
||||
|
@ -39,10 +35,8 @@ quickcheck! {
|
|||
let transformed_left_eigvectors = m.transpose() * eig.left_eigenvectors.unwrap();
|
||||
let scaled_left_eigvectors = eig.left_eigenvectors.unwrap() * eigvals;
|
||||
|
||||
relative_eq!(transformed_eigvectors, scaled_eigvectors, epsilon = 1.0e-7) &&
|
||||
relative_eq!(transformed_left_eigvectors, scaled_left_eigvectors, epsilon = 1.0e-7)
|
||||
},
|
||||
None => true
|
||||
prop_assert!(relative_eq!(transformed_eigvectors, scaled_eigvectors, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(transformed_left_eigvectors, scaled_left_eigvectors, epsilon = 1.0e-7));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,20 +1,24 @@
|
|||
use na::{DMatrix, Matrix4};
|
||||
use na::DMatrix;
|
||||
use nl::Schur;
|
||||
use std::cmp;
|
||||
|
||||
quickcheck! {
|
||||
fn schur(n: usize) -> bool {
|
||||
use crate::proptest::*;
|
||||
use proptest::{prop_assert, proptest};
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn schur(n in PROPTEST_MATRIX_DIM) {
|
||||
let n = cmp::max(1, cmp::min(n, 10));
|
||||
let m = DMatrix::<f64>::new_random(n, n);
|
||||
|
||||
let (vecs, vals) = Schur::new(m.clone()).unpack();
|
||||
|
||||
relative_eq!(&vecs * vals * vecs.transpose(), m, epsilon = 1.0e-7)
|
||||
prop_assert!(relative_eq!(&vecs * vals * vecs.transpose(), m, epsilon = 1.0e-7))
|
||||
}
|
||||
|
||||
fn schur_static(m: Matrix4<f64>) -> bool {
|
||||
#[test]
|
||||
fn schur_static(m in matrix4()) {
|
||||
let (vecs, vals) = Schur::new(m.clone()).unpack();
|
||||
|
||||
relative_eq!(vecs * vals * vecs.transpose(), m, epsilon = 1.0e-7)
|
||||
prop_assert!(relative_eq!(vecs * vals * vecs.transpose(), m, epsilon = 1.0e-7))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,57 +1,53 @@
|
|||
use na::{DMatrix, Matrix3x4};
|
||||
use na::{DMatrix, Matrix3x5};
|
||||
use nl::SVD;
|
||||
|
||||
quickcheck! {
|
||||
fn svd(m: DMatrix<f64>) -> bool {
|
||||
if m.nrows() != 0 && m.ncols() != 0 {
|
||||
use crate::proptest::*;
|
||||
use proptest::{prop_assert, proptest};
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn svd(m in dmatrix()) {
|
||||
let svd = SVD::new(m.clone()).unwrap();
|
||||
let sm = DMatrix::from_partial_diagonal(m.nrows(), m.ncols(), svd.singular_values.as_slice());
|
||||
|
||||
let reconstructed_m = &svd.u * sm * &svd.vt;
|
||||
let reconstructed_m2 = svd.recompose();
|
||||
|
||||
relative_eq!(reconstructed_m, m, epsilon = 1.0e-7) &&
|
||||
relative_eq!(reconstructed_m2, reconstructed_m, epsilon = 1.0e-7)
|
||||
}
|
||||
else {
|
||||
true
|
||||
}
|
||||
prop_assert!(relative_eq!(reconstructed_m, m, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(reconstructed_m2, reconstructed_m, epsilon = 1.0e-7));
|
||||
}
|
||||
|
||||
fn svd_static(m: Matrix3x4<f64>) -> bool {
|
||||
#[test]
|
||||
fn svd_static(m in matrix3x5()) {
|
||||
let svd = SVD::new(m).unwrap();
|
||||
let sm = Matrix3x4::from_partial_diagonal(svd.singular_values.as_slice());
|
||||
let sm = Matrix3x5::from_partial_diagonal(svd.singular_values.as_slice());
|
||||
|
||||
let reconstructed_m = &svd.u * &sm * &svd.vt;
|
||||
let reconstructed_m2 = svd.recompose();
|
||||
|
||||
relative_eq!(reconstructed_m, m, epsilon = 1.0e-7) &&
|
||||
relative_eq!(reconstructed_m2, m, epsilon = 1.0e-7)
|
||||
}
|
||||
|
||||
fn pseudo_inverse(m: DMatrix<f64>) -> bool {
|
||||
if m.nrows() == 0 || m.ncols() == 0 {
|
||||
return true;
|
||||
prop_assert!(relative_eq!(reconstructed_m, m, epsilon = 1.0e-7));
|
||||
prop_assert!(relative_eq!(reconstructed_m2, m, epsilon = 1.0e-7));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pseudo_inverse(m in dmatrix()) {
|
||||
let svd = SVD::new(m.clone()).unwrap();
|
||||
let im = svd.pseudo_inverse(1.0e-7);
|
||||
|
||||
if m.nrows() <= m.ncols() {
|
||||
return (&m * &im).is_identity(1.0e-7)
|
||||
prop_assert!((&m * &im).is_identity(1.0e-7));
|
||||
}
|
||||
|
||||
if m.nrows() >= m.ncols() {
|
||||
return (im * m).is_identity(1.0e-7)
|
||||
prop_assert!((im * m).is_identity(1.0e-7));
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
fn pseudo_inverse_static(m: Matrix3x4<f64>) -> bool {
|
||||
#[test]
|
||||
fn pseudo_inverse_static(m in matrix3x5()) {
|
||||
let svd = SVD::new(m).unwrap();
|
||||
let im = svd.pseudo_inverse(1.0e-7);
|
||||
|
||||
(m * im).is_identity(1.0e-7)
|
||||
prop_assert!((m * im).is_identity(1.0e-7))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,20 +1,25 @@
|
|||
use std::cmp;
|
||||
|
||||
use na::{DMatrix, Matrix4};
|
||||
use na::DMatrix;
|
||||
use nl::SymmetricEigen;
|
||||
|
||||
quickcheck! {
|
||||
fn symmetric_eigen(n: usize) -> bool {
|
||||
use crate::proptest::*;
|
||||
use proptest::{prop_assert, proptest};
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn symmetric_eigen(n in PROPTEST_MATRIX_DIM) {
|
||||
let n = cmp::max(1, cmp::min(n, 10));
|
||||
let m = DMatrix::<f64>::new_random(n, n);
|
||||
let eig = SymmetricEigen::new(m.clone());
|
||||
let recomp = eig.recompose();
|
||||
relative_eq!(m.lower_triangle(), recomp.lower_triangle(), epsilon = 1.0e-5)
|
||||
prop_assert!(relative_eq!(m.lower_triangle(), recomp.lower_triangle(), epsilon = 1.0e-5))
|
||||
}
|
||||
|
||||
fn symmetric_eigen_static(m: Matrix4<f64>) -> bool {
|
||||
#[test]
|
||||
fn symmetric_eigen_static(m in matrix4()) {
|
||||
let eig = SymmetricEigen::new(m);
|
||||
let recomp = eig.recompose();
|
||||
relative_eq!(m.lower_triangle(), recomp.lower_triangle(), epsilon = 1.0e-5)
|
||||
prop_assert!(relative_eq!(m.lower_triangle(), recomp.lower_triangle(), epsilon = 1.0e-5))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
[package]
|
||||
name = "nalgebra-sparse"
|
||||
version = "0.1.0"
|
||||
authors = [ "Andreas Longva", "Sébastien Crozet <developer@crozet.re>" ]
|
||||
edition = "2018"
|
||||
|
||||
[features]
|
||||
proptest-support = ["proptest", "nalgebra/proptest-support"]
|
||||
compare = [ "matrixcompare-core" ]
|
||||
|
||||
# Enable to enable running some tests that take a lot of time to run
|
||||
slow-tests = []
|
||||
|
||||
[dependencies]
|
||||
nalgebra = { version="0.25", path = "../" }
|
||||
num-traits = { version = "0.2", default-features = false }
|
||||
proptest = { version = "1.0", optional = true }
|
||||
matrixcompare-core = { version = "0.1.0", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
itertools = "0.10"
|
||||
matrixcompare = { version = "0.2.0", features = [ "proptest-support" ] }
|
||||
nalgebra = { version="0.25", path = "../", features = ["compare"] }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
# Enable certain features when building docs for docs.rs
|
||||
features = [ "proptest-support", "compare" ]
|
|
@ -0,0 +1,124 @@
|
|||
use crate::convert::serial::*;
|
||||
use crate::coo::CooMatrix;
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::csr::CsrMatrix;
|
||||
use nalgebra::storage::Storage;
|
||||
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
|
||||
use num_traits::Zero;
|
||||
|
||||
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CooMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>,
|
||||
{
|
||||
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
||||
convert_dense_coo(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a CooMatrix<T>> for DMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd,
|
||||
{
|
||||
fn from(coo: &'a CooMatrix<T>) -> Self {
|
||||
convert_coo_dense(coo)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a CooMatrix<T>> for CsrMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd,
|
||||
{
|
||||
fn from(matrix: &'a CooMatrix<T>) -> Self {
|
||||
convert_coo_csr(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a CsrMatrix<T>> for CooMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd,
|
||||
{
|
||||
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
||||
convert_csr_coo(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CsrMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>,
|
||||
{
|
||||
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
||||
convert_dense_csr(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a CsrMatrix<T>> for DMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd,
|
||||
{
|
||||
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
||||
convert_csr_dense(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a CooMatrix<T>> for CscMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd,
|
||||
{
|
||||
fn from(matrix: &'a CooMatrix<T>) -> Self {
|
||||
convert_coo_csc(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a CscMatrix<T>> for CooMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
{
|
||||
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
||||
convert_csc_coo(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, R, C, S> From<&'a Matrix<T, R, C, S>> for CscMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>,
|
||||
{
|
||||
fn from(matrix: &'a Matrix<T, R, C, S>) -> Self {
|
||||
convert_dense_csc(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a CscMatrix<T>> for DMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd,
|
||||
{
|
||||
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
||||
convert_csc_dense(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a CscMatrix<T>> for CsrMatrix<T>
|
||||
where
|
||||
T: Scalar,
|
||||
{
|
||||
fn from(matrix: &'a CscMatrix<T>) -> Self {
|
||||
convert_csc_csr(matrix)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> From<&'a CsrMatrix<T>> for CscMatrix<T>
|
||||
where
|
||||
T: Scalar,
|
||||
{
|
||||
fn from(matrix: &'a CsrMatrix<T>) -> Self {
|
||||
convert_csr_csc(matrix)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
//! Routines for converting between sparse matrix formats.
|
||||
//!
|
||||
//! Most users should instead use the provided `From` implementations to convert between matrix
|
||||
//! formats. Note that `From` implementations may not be available between all combinations of
|
||||
//! sparse matrices.
|
||||
//!
|
||||
//! The following example illustrates how to convert between matrix formats with the `From`
|
||||
//! implementations.
|
||||
//!
|
||||
//! ```rust
|
||||
//! use nalgebra_sparse::{csr::CsrMatrix, csc::CscMatrix, coo::CooMatrix};
|
||||
//! use nalgebra::DMatrix;
|
||||
//!
|
||||
//! // Conversion from dense
|
||||
//! let dense = DMatrix::<f64>::identity(9, 8);
|
||||
//! let csr = CsrMatrix::from(&dense);
|
||||
//! let csc = CscMatrix::from(&dense);
|
||||
//! let coo = CooMatrix::from(&dense);
|
||||
//!
|
||||
//! // CSR <-> CSC
|
||||
//! let _ = CsrMatrix::from(&csc);
|
||||
//! let _ = CscMatrix::from(&csr);
|
||||
//!
|
||||
//! // CSR <-> COO
|
||||
//! let _ = CooMatrix::from(&csr);
|
||||
//! let _ = CsrMatrix::from(&coo);
|
||||
//!
|
||||
//! // CSC <-> COO
|
||||
//! let _ = CooMatrix::from(&csc);
|
||||
//! let _ = CscMatrix::from(&coo);
|
||||
//! ```
|
||||
//!
|
||||
//! The routines available here are able to provide more specialized APIs, giving
|
||||
//! more control over the conversion process. The routines are organized by backends.
|
||||
//! Currently, only the [`serial`] backend is available.
|
||||
//! In the future, backends that offer parallel routines may become available.
|
||||
|
||||
pub mod serial;
|
||||
|
||||
mod impl_std_ops;
|
|
@ -0,0 +1,427 @@
|
|||
//! Serial routines for converting between matrix formats.
|
||||
//!
|
||||
//! All routines in this module are single-threaded. At present these routines offer no
|
||||
//! advantage over using the [`From`] trait, but future changes to the API might offer more
|
||||
//! control to the user.
|
||||
use std::ops::Add;
|
||||
|
||||
use num_traits::Zero;
|
||||
|
||||
use nalgebra::storage::Storage;
|
||||
use nalgebra::{ClosedAdd, DMatrix, Dim, Matrix, Scalar};
|
||||
|
||||
use crate::coo::CooMatrix;
|
||||
use crate::cs;
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::csr::CsrMatrix;
|
||||
|
||||
/// Converts a dense matrix to [`CooMatrix`].
|
||||
pub fn convert_dense_coo<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CooMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>,
|
||||
{
|
||||
let mut coo = CooMatrix::new(dense.nrows(), dense.ncols());
|
||||
|
||||
for (index, v) in dense.iter().enumerate() {
|
||||
if v != &T::zero() {
|
||||
// We use the fact that matrix iteration is guaranteed to be column-major
|
||||
let i = index % dense.nrows();
|
||||
let j = index / dense.nrows();
|
||||
coo.push(i, j, v.inlined_clone());
|
||||
}
|
||||
}
|
||||
|
||||
coo
|
||||
}
|
||||
|
||||
/// Converts a [`CooMatrix`] to a dense matrix.
|
||||
pub fn convert_coo_dense<T>(coo: &CooMatrix<T>) -> DMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero + ClosedAdd,
|
||||
{
|
||||
let mut output = DMatrix::repeat(coo.nrows(), coo.ncols(), T::zero());
|
||||
for (i, j, v) in coo.triplet_iter() {
|
||||
output[(i, j)] += v.inlined_clone();
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
/// Converts a [`CooMatrix`] to a [`CsrMatrix`].
|
||||
pub fn convert_coo_csr<T>(coo: &CooMatrix<T>) -> CsrMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
{
|
||||
let (offsets, indices, values) = convert_coo_cs(
|
||||
coo.nrows(),
|
||||
coo.row_indices(),
|
||||
coo.col_indices(),
|
||||
coo.values(),
|
||||
);
|
||||
|
||||
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
|
||||
// to see if it can be justified for performance reasons)
|
||||
CsrMatrix::try_from_csr_data(coo.nrows(), coo.ncols(), offsets, indices, values)
|
||||
.expect("Internal error: Invalid CSR data during COO->CSR conversion")
|
||||
}
|
||||
|
||||
/// Converts a [`CsrMatrix`] to a [`CooMatrix`].
|
||||
pub fn convert_csr_coo<T: Scalar>(csr: &CsrMatrix<T>) -> CooMatrix<T> {
|
||||
let mut result = CooMatrix::new(csr.nrows(), csr.ncols());
|
||||
for (i, j, v) in csr.triplet_iter() {
|
||||
result.push(i, j, v.inlined_clone());
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Converts a [`CsrMatrix`] to a dense matrix.
|
||||
pub fn convert_csr_dense<T>(csr: &CsrMatrix<T>) -> DMatrix<T>
|
||||
where
|
||||
T: Scalar + ClosedAdd + Zero,
|
||||
{
|
||||
let mut output = DMatrix::zeros(csr.nrows(), csr.ncols());
|
||||
|
||||
for (i, j, v) in csr.triplet_iter() {
|
||||
output[(i, j)] += v.inlined_clone();
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Converts a dense matrix to a [`CsrMatrix`].
|
||||
pub fn convert_dense_csr<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CsrMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>,
|
||||
{
|
||||
let mut row_offsets = Vec::with_capacity(dense.nrows() + 1);
|
||||
let mut col_idx = Vec::new();
|
||||
let mut values = Vec::new();
|
||||
|
||||
// We have to iterate row-by-row to build the CSR matrix, which is at odds with
|
||||
// nalgebra's column-major storage. The alternative would be to perform an initial sweep
|
||||
// to count number of non-zeros per row.
|
||||
row_offsets.push(0);
|
||||
for i in 0..dense.nrows() {
|
||||
for j in 0..dense.ncols() {
|
||||
let v = dense.index((i, j));
|
||||
if v != &T::zero() {
|
||||
col_idx.push(j);
|
||||
values.push(v.inlined_clone());
|
||||
}
|
||||
}
|
||||
row_offsets.push(col_idx.len());
|
||||
}
|
||||
|
||||
// TODO: Consider circumventing the data validity check here
|
||||
// (would require unsafe, should benchmark)
|
||||
CsrMatrix::try_from_csr_data(dense.nrows(), dense.ncols(), row_offsets, col_idx, values)
|
||||
.expect("Internal error: Invalid CsrMatrix format during dense-> CSR conversion")
|
||||
}
|
||||
|
||||
/// Converts a [`CooMatrix`] to a [`CscMatrix`].
|
||||
pub fn convert_coo_csc<T>(coo: &CooMatrix<T>) -> CscMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
{
|
||||
let (offsets, indices, values) = convert_coo_cs(
|
||||
coo.ncols(),
|
||||
coo.col_indices(),
|
||||
coo.row_indices(),
|
||||
coo.values(),
|
||||
);
|
||||
|
||||
// TODO: Avoid "try_from" since it validates the data? (requires unsafe, should benchmark
|
||||
// to see if it can be justified for performance reasons)
|
||||
CscMatrix::try_from_csc_data(coo.nrows(), coo.ncols(), offsets, indices, values)
|
||||
.expect("Internal error: Invalid CSC data during COO->CSC conversion")
|
||||
}
|
||||
|
||||
/// Converts a [`CscMatrix`] to a [`CooMatrix`].
|
||||
pub fn convert_csc_coo<T>(csc: &CscMatrix<T>) -> CooMatrix<T>
|
||||
where
|
||||
T: Scalar,
|
||||
{
|
||||
let mut coo = CooMatrix::new(csc.nrows(), csc.ncols());
|
||||
for (i, j, v) in csc.triplet_iter() {
|
||||
coo.push(i, j, v.inlined_clone());
|
||||
}
|
||||
coo
|
||||
}
|
||||
|
||||
/// Converts a [`CscMatrix`] to a dense matrix.
|
||||
pub fn convert_csc_dense<T>(csc: &CscMatrix<T>) -> DMatrix<T>
|
||||
where
|
||||
T: Scalar + ClosedAdd + Zero,
|
||||
{
|
||||
let mut output = DMatrix::zeros(csc.nrows(), csc.ncols());
|
||||
|
||||
for (i, j, v) in csc.triplet_iter() {
|
||||
output[(i, j)] += v.inlined_clone();
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Converts a dense matrix to a [`CscMatrix`].
|
||||
pub fn convert_dense_csc<T, R, C, S>(dense: &Matrix<T, R, C, S>) -> CscMatrix<T>
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>,
|
||||
{
|
||||
let mut col_offsets = Vec::with_capacity(dense.ncols() + 1);
|
||||
let mut row_idx = Vec::new();
|
||||
let mut values = Vec::new();
|
||||
|
||||
col_offsets.push(0);
|
||||
for j in 0..dense.ncols() {
|
||||
for i in 0..dense.nrows() {
|
||||
let v = dense.index((i, j));
|
||||
if v != &T::zero() {
|
||||
row_idx.push(i);
|
||||
values.push(v.inlined_clone());
|
||||
}
|
||||
}
|
||||
col_offsets.push(row_idx.len());
|
||||
}
|
||||
|
||||
// TODO: Consider circumventing the data validity check here
|
||||
// (would require unsafe, should benchmark)
|
||||
CscMatrix::try_from_csc_data(dense.nrows(), dense.ncols(), col_offsets, row_idx, values)
|
||||
.expect("Internal error: Invalid CscMatrix format during dense-> CSC conversion")
|
||||
}
|
||||
|
||||
/// Converts a [`CsrMatrix`] to a [`CscMatrix`].
|
||||
pub fn convert_csr_csc<T>(csr: &CsrMatrix<T>) -> CscMatrix<T>
|
||||
where
|
||||
T: Scalar,
|
||||
{
|
||||
let (offsets, indices, values) = cs::transpose_cs(
|
||||
csr.nrows(),
|
||||
csr.ncols(),
|
||||
csr.row_offsets(),
|
||||
csr.col_indices(),
|
||||
csr.values(),
|
||||
);
|
||||
|
||||
// TODO: Avoid data validity check?
|
||||
CscMatrix::try_from_csc_data(csr.nrows(), csr.ncols(), offsets, indices, values)
|
||||
.expect("Internal error: Invalid CSC data during CSR->CSC conversion")
|
||||
}
|
||||
|
||||
/// Converts a [`CscMatrix`] to a [`CsrMatrix`].
|
||||
pub fn convert_csc_csr<T>(csc: &CscMatrix<T>) -> CsrMatrix<T>
|
||||
where
|
||||
T: Scalar,
|
||||
{
|
||||
let (offsets, indices, values) = cs::transpose_cs(
|
||||
csc.ncols(),
|
||||
csc.nrows(),
|
||||
csc.col_offsets(),
|
||||
csc.row_indices(),
|
||||
csc.values(),
|
||||
);
|
||||
|
||||
// TODO: Avoid data validity check?
|
||||
CsrMatrix::try_from_csr_data(csc.nrows(), csc.ncols(), offsets, indices, values)
|
||||
.expect("Internal error: Invalid CSR data during CSC->CSR conversion")
|
||||
}
|
||||
|
||||
fn convert_coo_cs<T>(
|
||||
major_dim: usize,
|
||||
major_indices: &[usize],
|
||||
minor_indices: &[usize],
|
||||
values: &[T],
|
||||
) -> (Vec<usize>, Vec<usize>, Vec<T>)
|
||||
where
|
||||
T: Scalar + Zero,
|
||||
{
|
||||
assert_eq!(major_indices.len(), minor_indices.len());
|
||||
assert_eq!(minor_indices.len(), values.len());
|
||||
let nnz = major_indices.len();
|
||||
|
||||
let (unsorted_major_offsets, unsorted_minor_idx, unsorted_vals) = {
|
||||
let mut offsets = vec![0usize; major_dim + 1];
|
||||
let mut minor_idx = vec![0usize; nnz];
|
||||
let mut vals = vec![T::zero(); nnz];
|
||||
coo_to_unsorted_cs(
|
||||
&mut offsets,
|
||||
&mut minor_idx,
|
||||
&mut vals,
|
||||
major_dim,
|
||||
major_indices,
|
||||
minor_indices,
|
||||
values,
|
||||
);
|
||||
(offsets, minor_idx, vals)
|
||||
};
|
||||
|
||||
// TODO: If input is sorted and/or without duplicates, we can avoid additional allocations
|
||||
// and work. Might want to take advantage of this.
|
||||
|
||||
// At this point, assembly is essentially complete. However, we must ensure
|
||||
// that minor indices are sorted within each lane and without duplicates.
|
||||
let mut sorted_major_offsets = Vec::new();
|
||||
let mut sorted_minor_idx = Vec::new();
|
||||
let mut sorted_vals = Vec::new();
|
||||
|
||||
sorted_major_offsets.push(0);
|
||||
|
||||
// We need some temporary storage when working with each lane. Since lanes often have a
|
||||
// very small number of non-zero entries, we try to amortize allocations across
|
||||
// lanes by reusing workspace vectors
|
||||
let mut idx_workspace = Vec::new();
|
||||
let mut perm_workspace = Vec::new();
|
||||
let mut values_workspace = Vec::new();
|
||||
|
||||
for lane in 0..major_dim {
|
||||
let begin = unsorted_major_offsets[lane];
|
||||
let end = unsorted_major_offsets[lane + 1];
|
||||
let count = end - begin;
|
||||
let range = begin..end;
|
||||
|
||||
// Ensure that workspaces can hold enough data
|
||||
perm_workspace.resize(count, 0);
|
||||
idx_workspace.resize(count, 0);
|
||||
values_workspace.resize(count, T::zero());
|
||||
sort_lane(
|
||||
&mut idx_workspace[..count],
|
||||
&mut values_workspace[..count],
|
||||
&unsorted_minor_idx[range.clone()],
|
||||
&unsorted_vals[range.clone()],
|
||||
&mut perm_workspace[..count],
|
||||
);
|
||||
|
||||
let sorted_ja_current_len = sorted_minor_idx.len();
|
||||
|
||||
combine_duplicates(
|
||||
|idx| sorted_minor_idx.push(idx),
|
||||
|val| sorted_vals.push(val),
|
||||
&idx_workspace[..count],
|
||||
&values_workspace[..count],
|
||||
&Add::add,
|
||||
);
|
||||
|
||||
let new_col_count = sorted_minor_idx.len() - sorted_ja_current_len;
|
||||
sorted_major_offsets.push(sorted_major_offsets.last().unwrap() + new_col_count);
|
||||
}
|
||||
|
||||
(sorted_major_offsets, sorted_minor_idx, sorted_vals)
|
||||
}
|
||||
|
||||
/// Converts matrix data given in triplet format to unsorted CSR/CSC, retaining any duplicated
|
||||
/// indices.
|
||||
///
|
||||
/// Here `major/minor` is `row/col` for CSR and `col/row` for CSC.
|
||||
fn coo_to_unsorted_cs<T: Clone>(
|
||||
major_offsets: &mut [usize],
|
||||
cs_minor_idx: &mut [usize],
|
||||
cs_values: &mut [T],
|
||||
major_dim: usize,
|
||||
major_indices: &[usize],
|
||||
minor_indices: &[usize],
|
||||
coo_values: &[T],
|
||||
) {
|
||||
assert_eq!(major_offsets.len(), major_dim + 1);
|
||||
assert_eq!(cs_minor_idx.len(), cs_values.len());
|
||||
assert_eq!(cs_values.len(), major_indices.len());
|
||||
assert_eq!(major_indices.len(), minor_indices.len());
|
||||
assert_eq!(minor_indices.len(), coo_values.len());
|
||||
|
||||
// Count the number of occurrences of each row
|
||||
for major_idx in major_indices {
|
||||
major_offsets[*major_idx] += 1;
|
||||
}
|
||||
|
||||
cs::convert_counts_to_offsets(major_offsets);
|
||||
|
||||
{
|
||||
// TODO: Instead of allocating a whole new vector storing the current counts,
|
||||
// I think it's possible to be a bit more clever by storing each count
|
||||
// in the last of the column indices for each row
|
||||
let mut current_counts = vec![0usize; major_dim + 1];
|
||||
let triplet_iter = major_indices.iter().zip(minor_indices).zip(coo_values);
|
||||
for ((i, j), value) in triplet_iter {
|
||||
let current_offset = major_offsets[*i] + current_counts[*i];
|
||||
cs_minor_idx[current_offset] = *j;
|
||||
cs_values[current_offset] = value.clone();
|
||||
current_counts[*i] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Sort the indices of the given lane.
|
||||
///
|
||||
/// The indices and values in `minor_idx` and `values` are sorted according to the
|
||||
/// minor indices and stored in `minor_idx_result` and `values_result` respectively.
|
||||
///
|
||||
/// All input slices are expected to be of the same length. The contents of mutable slices
|
||||
/// can be arbitrary, as they are anyway overwritten.
|
||||
fn sort_lane<T: Clone>(
|
||||
minor_idx_result: &mut [usize],
|
||||
values_result: &mut [T],
|
||||
minor_idx: &[usize],
|
||||
values: &[T],
|
||||
workspace: &mut [usize],
|
||||
) {
|
||||
assert_eq!(minor_idx_result.len(), values_result.len());
|
||||
assert_eq!(values_result.len(), minor_idx.len());
|
||||
assert_eq!(minor_idx.len(), values.len());
|
||||
assert_eq!(values.len(), workspace.len());
|
||||
|
||||
let permutation = workspace;
|
||||
// Set permutation to identity
|
||||
for (i, p) in permutation.iter_mut().enumerate() {
|
||||
*p = i;
|
||||
}
|
||||
|
||||
// Compute permutation needed to bring minor indices into sorted order
|
||||
// Note: Using sort_unstable here avoids internal allocations, which is crucial since
|
||||
// each lane might have a small number of elements
|
||||
permutation.sort_unstable_by_key(|idx| minor_idx[*idx]);
|
||||
|
||||
apply_permutation(minor_idx_result, minor_idx, permutation);
|
||||
apply_permutation(values_result, values, permutation);
|
||||
}
|
||||
|
||||
// TODO: Move this into `utils` or something?
|
||||
fn apply_permutation<T: Clone>(out_slice: &mut [T], in_slice: &[T], permutation: &[usize]) {
|
||||
assert_eq!(out_slice.len(), in_slice.len());
|
||||
assert_eq!(out_slice.len(), permutation.len());
|
||||
for (out_element, old_pos) in out_slice.iter_mut().zip(permutation) {
|
||||
*out_element = in_slice[*old_pos].clone();
|
||||
}
|
||||
}
|
||||
|
||||
/// Given *sorted* indices and corresponding scalar values, combines duplicates with the given
|
||||
/// associative combiner and calls the provided produce methods with combined indices and values.
|
||||
fn combine_duplicates<T: Clone>(
|
||||
mut produce_idx: impl FnMut(usize),
|
||||
mut produce_value: impl FnMut(T),
|
||||
idx_array: &[usize],
|
||||
values: &[T],
|
||||
combiner: impl Fn(T, T) -> T,
|
||||
) {
|
||||
assert_eq!(idx_array.len(), values.len());
|
||||
|
||||
let mut i = 0;
|
||||
while i < idx_array.len() {
|
||||
let idx = idx_array[i];
|
||||
let mut combined_value = values[i].clone();
|
||||
let mut j = i + 1;
|
||||
while j < idx_array.len() && idx_array[j] == idx {
|
||||
let j_val = values[j].clone();
|
||||
combined_value = combiner(combined_value, j_val);
|
||||
j += 1;
|
||||
}
|
||||
produce_idx(idx);
|
||||
produce_value(combined_value);
|
||||
i = j;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,208 @@
|
|||
//! An implementation of the COO sparse matrix format.
|
||||
|
||||
use crate::SparseFormatError;
|
||||
|
||||
/// A COO representation of a sparse matrix.
|
||||
///
|
||||
/// A COO matrix stores entries in coordinate-form, that is triplets `(i, j, v)`, where `i` and `j`
|
||||
/// correspond to row and column indices of the entry, and `v` to the value of the entry.
|
||||
/// The format is of limited use for standard matrix operations. Its main purpose is to facilitate
|
||||
/// easy construction of other, more efficient matrix formats (such as CSR/COO), and the
|
||||
/// conversion between different formats.
|
||||
///
|
||||
/// # Format
|
||||
///
|
||||
/// For given dimensions `nrows` and `ncols`, the matrix is represented by three same-length
|
||||
/// arrays `row_indices`, `col_indices` and `values` that constitute the coordinate triplets
|
||||
/// of the matrix. The indices must be in bounds, but *duplicate entries are explicitly allowed*.
|
||||
/// Upon conversion to other formats, the duplicate entries may be summed together. See the
|
||||
/// documentation for the respective conversion functions.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix, csc::CscMatrix};
|
||||
///
|
||||
/// // Initialize a matrix with all zeros (no explicitly stored entries).
|
||||
/// let mut coo = CooMatrix::new(4, 4);
|
||||
/// // Or initialize it with a set of triplets
|
||||
/// coo = CooMatrix::try_from_triplets(4, 4, vec![1, 2], vec![0, 1], vec![3.0, 4.0]).unwrap();
|
||||
///
|
||||
/// // Push a few triplets
|
||||
/// coo.push(2, 0, 1.0);
|
||||
/// coo.push(0, 1, 2.0);
|
||||
///
|
||||
/// // Convert to other matrix formats
|
||||
/// let csr = CsrMatrix::from(&coo);
|
||||
/// let csc = CscMatrix::from(&coo);
|
||||
/// ```
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CooMatrix<T> {
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
row_indices: Vec<usize>,
|
||||
col_indices: Vec<usize>,
|
||||
values: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T> CooMatrix<T> {
|
||||
/// Construct a zero COO matrix of the given dimensions.
|
||||
///
|
||||
/// Specifically, the collection of triplets - corresponding to explicitly stored entries -
|
||||
/// is empty, so that the matrix (implicitly) represented by the COO matrix consists of all
|
||||
/// zero entries.
|
||||
pub fn new(nrows: usize, ncols: usize) -> Self {
|
||||
Self {
|
||||
nrows,
|
||||
ncols,
|
||||
row_indices: Vec::new(),
|
||||
col_indices: Vec::new(),
|
||||
values: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct a zero COO matrix of the given dimensions.
|
||||
///
|
||||
/// Specifically, the collection of triplets - corresponding to explicitly stored entries -
|
||||
/// is empty, so that the matrix (implicitly) represented by the COO matrix consists of all
|
||||
/// zero entries.
|
||||
pub fn zeros(nrows: usize, ncols: usize) -> Self {
|
||||
Self::new(nrows, ncols)
|
||||
}
|
||||
|
||||
/// Try to construct a COO matrix from the given dimensions and a collection of
|
||||
/// (i, j, v) triplets.
|
||||
///
|
||||
/// Returns an error if either row or column indices contain indices out of bounds,
|
||||
/// or if the data arrays do not all have the same length. Note that the COO format
|
||||
/// inherently supports duplicate entries.
|
||||
pub fn try_from_triplets(
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
row_indices: Vec<usize>,
|
||||
col_indices: Vec<usize>,
|
||||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
use crate::SparseFormatErrorKind::*;
|
||||
if row_indices.len() != col_indices.len() {
|
||||
return Err(SparseFormatError::from_kind_and_msg(
|
||||
InvalidStructure,
|
||||
"Number of row and col indices must be the same.",
|
||||
));
|
||||
} else if col_indices.len() != values.len() {
|
||||
return Err(SparseFormatError::from_kind_and_msg(
|
||||
InvalidStructure,
|
||||
"Number of col indices and values must be the same.",
|
||||
));
|
||||
}
|
||||
|
||||
let row_indices_in_bounds = row_indices.iter().all(|i| *i < nrows);
|
||||
let col_indices_in_bounds = col_indices.iter().all(|j| *j < ncols);
|
||||
|
||||
if !row_indices_in_bounds {
|
||||
Err(SparseFormatError::from_kind_and_msg(
|
||||
IndexOutOfBounds,
|
||||
"Row index out of bounds.",
|
||||
))
|
||||
} else if !col_indices_in_bounds {
|
||||
Err(SparseFormatError::from_kind_and_msg(
|
||||
IndexOutOfBounds,
|
||||
"Col index out of bounds.",
|
||||
))
|
||||
} else {
|
||||
Ok(Self {
|
||||
nrows,
|
||||
ncols,
|
||||
row_indices,
|
||||
col_indices,
|
||||
values,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// An iterator over triplets (i, j, v).
|
||||
// TODO: Consider giving the iterator a concrete type instead of impl trait...?
|
||||
pub fn triplet_iter(&self) -> impl Iterator<Item = (usize, usize, &T)> {
|
||||
self.row_indices
|
||||
.iter()
|
||||
.zip(&self.col_indices)
|
||||
.zip(&self.values)
|
||||
.map(|((i, j), v)| (*i, *j, v))
|
||||
}
|
||||
|
||||
/// Push a single triplet to the matrix.
|
||||
///
|
||||
/// This adds the value `v` to the `i`th row and `j`th column in the matrix.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
///
|
||||
/// Panics if `i` or `j` is out of bounds.
|
||||
#[inline]
|
||||
pub fn push(&mut self, i: usize, j: usize, v: T) {
|
||||
assert!(i < self.nrows);
|
||||
assert!(j < self.ncols);
|
||||
self.row_indices.push(i);
|
||||
self.col_indices.push(j);
|
||||
self.values.push(v);
|
||||
}
|
||||
|
||||
/// The number of rows in the matrix.
|
||||
#[inline]
|
||||
pub fn nrows(&self) -> usize {
|
||||
self.nrows
|
||||
}
|
||||
|
||||
/// The number of columns in the matrix.
|
||||
#[inline]
|
||||
pub fn ncols(&self) -> usize {
|
||||
self.ncols
|
||||
}
|
||||
|
||||
/// The number of explicitly stored entries in the matrix.
|
||||
///
|
||||
/// This number *includes* duplicate entries. For example, if the `CooMatrix` contains duplicate
|
||||
/// entries, then it may have a different number of non-zeros as reported by `nnz()` compared
|
||||
/// to its CSR representation.
|
||||
#[inline]
|
||||
pub fn nnz(&self) -> usize {
|
||||
self.values.len()
|
||||
}
|
||||
|
||||
/// The row indices of the explicitly stored entries.
|
||||
pub fn row_indices(&self) -> &[usize] {
|
||||
&self.row_indices
|
||||
}
|
||||
|
||||
/// The column indices of the explicitly stored entries.
|
||||
pub fn col_indices(&self) -> &[usize] {
|
||||
&self.col_indices
|
||||
}
|
||||
|
||||
/// The values of the explicitly stored entries.
|
||||
pub fn values(&self) -> &[T] {
|
||||
&self.values
|
||||
}
|
||||
|
||||
/// Disassembles the matrix into individual triplet arrays.
|
||||
///
|
||||
/// Examples
|
||||
/// --------
|
||||
///
|
||||
/// ```
|
||||
/// # use nalgebra_sparse::coo::CooMatrix;
|
||||
/// let row_indices = vec![0, 1];
|
||||
/// let col_indices = vec![1, 2];
|
||||
/// let values = vec![1.0, 2.0];
|
||||
/// let coo = CooMatrix::try_from_triplets(2, 3, row_indices, col_indices, values)
|
||||
/// .unwrap();
|
||||
///
|
||||
/// let (row_idx, col_idx, val) = coo.disassemble();
|
||||
/// assert_eq!(row_idx, vec![0, 1]);
|
||||
/// assert_eq!(col_idx, vec![1, 2]);
|
||||
/// assert_eq!(val, vec![1.0, 2.0]);
|
||||
/// ```
|
||||
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||
(self.row_indices, self.col_indices, self.values)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,530 @@
|
|||
use std::mem::replace;
|
||||
use std::ops::Range;
|
||||
|
||||
use num_traits::One;
|
||||
|
||||
use nalgebra::Scalar;
|
||||
|
||||
use crate::pattern::SparsityPattern;
|
||||
use crate::{SparseEntry, SparseEntryMut};
|
||||
|
||||
/// An abstract compressed matrix.
|
||||
///
|
||||
/// For the time being, this is only used internally to share implementation between
|
||||
/// CSR and CSC matrices.
|
||||
///
|
||||
/// A CSR matrix is obtained by associating rows with the major dimension, while a CSC matrix
|
||||
/// is obtained by associating columns with the major dimension.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CsMatrix<T> {
|
||||
sparsity_pattern: SparsityPattern,
|
||||
values: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T> CsMatrix<T> {
|
||||
/// Create a zero matrix with no explicitly stored entries.
|
||||
#[inline]
|
||||
pub fn new(major_dim: usize, minor_dim: usize) -> Self {
|
||||
Self {
|
||||
sparsity_pattern: SparsityPattern::zeros(major_dim, minor_dim),
|
||||
values: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn pattern(&self) -> &SparsityPattern {
|
||||
&self.sparsity_pattern
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn values(&self) -> &[T] {
|
||||
&self.values
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn values_mut(&mut self) -> &mut [T] {
|
||||
&mut self.values
|
||||
}
|
||||
|
||||
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
||||
#[inline]
|
||||
pub fn cs_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||
let pattern = self.pattern();
|
||||
(
|
||||
pattern.major_offsets(),
|
||||
pattern.minor_indices(),
|
||||
&self.values,
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the raw data represented as a tuple `(major_offsets, minor_indices, values)`.
|
||||
#[inline]
|
||||
pub fn cs_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||
let pattern = &mut self.sparsity_pattern;
|
||||
(
|
||||
pattern.major_offsets(),
|
||||
pattern.minor_indices(),
|
||||
&mut self.values,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
|
||||
(&self.sparsity_pattern, &mut self.values)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn from_pattern_and_values(pattern: SparsityPattern, values: Vec<T>) -> Self {
|
||||
assert_eq!(
|
||||
pattern.nnz(),
|
||||
values.len(),
|
||||
"Internal error: consumers should verify shape compatibility."
|
||||
);
|
||||
Self {
|
||||
sparsity_pattern: pattern,
|
||||
values,
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal method for simplifying access to a lane's data
|
||||
#[inline]
|
||||
pub fn get_index_range(&self, row_index: usize) -> Option<Range<usize>> {
|
||||
let row_begin = *self.sparsity_pattern.major_offsets().get(row_index)?;
|
||||
let row_end = *self.sparsity_pattern.major_offsets().get(row_index + 1)?;
|
||||
Some(row_begin..row_end)
|
||||
}
|
||||
|
||||
pub fn take_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
||||
(self.sparsity_pattern, self.values)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||
let (offsets, indices) = self.sparsity_pattern.disassemble();
|
||||
(offsets, indices, self.values)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
||||
(self.sparsity_pattern, self.values)
|
||||
}
|
||||
|
||||
/// Returns an entry for the given major/minor indices, or `None` if the indices are out
|
||||
/// of bounds.
|
||||
pub fn get_entry(&self, major_index: usize, minor_index: usize) -> Option<SparseEntry<T>> {
|
||||
let row_range = self.get_index_range(major_index)?;
|
||||
let (_, minor_indices, values) = self.cs_data();
|
||||
let minor_indices = &minor_indices[row_range.clone()];
|
||||
let values = &values[row_range];
|
||||
get_entry_from_slices(
|
||||
self.pattern().minor_dim(),
|
||||
minor_indices,
|
||||
values,
|
||||
minor_index,
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a mutable entry for the given major/minor indices, or `None` if the indices are out
|
||||
/// of bounds.
|
||||
pub fn get_entry_mut(
|
||||
&mut self,
|
||||
major_index: usize,
|
||||
minor_index: usize,
|
||||
) -> Option<SparseEntryMut<T>> {
|
||||
let row_range = self.get_index_range(major_index)?;
|
||||
let minor_dim = self.pattern().minor_dim();
|
||||
let (_, minor_indices, values) = self.cs_data_mut();
|
||||
let minor_indices = &minor_indices[row_range.clone()];
|
||||
let values = &mut values[row_range];
|
||||
get_mut_entry_from_slices(minor_dim, minor_indices, values, minor_index)
|
||||
}
|
||||
|
||||
pub fn get_lane(&self, index: usize) -> Option<CsLane<T>> {
|
||||
let range = self.get_index_range(index)?;
|
||||
let (_, minor_indices, values) = self.cs_data();
|
||||
Some(CsLane {
|
||||
minor_indices: &minor_indices[range.clone()],
|
||||
values: &values[range],
|
||||
minor_dim: self.pattern().minor_dim(),
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_lane_mut(&mut self, index: usize) -> Option<CsLaneMut<T>> {
|
||||
let range = self.get_index_range(index)?;
|
||||
let minor_dim = self.pattern().minor_dim();
|
||||
let (_, minor_indices, values) = self.cs_data_mut();
|
||||
Some(CsLaneMut {
|
||||
minor_dim,
|
||||
minor_indices: &minor_indices[range.clone()],
|
||||
values: &mut values[range],
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn lane_iter(&self) -> CsLaneIter<T> {
|
||||
CsLaneIter::new(self.pattern(), self.values())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn lane_iter_mut(&mut self) -> CsLaneIterMut<T> {
|
||||
CsLaneIterMut::new(&self.sparsity_pattern, &mut self.values)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn filter<P>(&self, predicate: P) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
P: Fn(usize, usize, &T) -> bool,
|
||||
{
|
||||
let (major_dim, minor_dim) = (self.pattern().major_dim(), self.pattern().minor_dim());
|
||||
let mut new_offsets = Vec::with_capacity(self.pattern().major_dim() + 1);
|
||||
let mut new_indices = Vec::new();
|
||||
let mut new_values = Vec::new();
|
||||
|
||||
new_offsets.push(0);
|
||||
for (i, lane) in self.lane_iter().enumerate() {
|
||||
for (&j, value) in lane.minor_indices().iter().zip(lane.values) {
|
||||
if predicate(i, j, value) {
|
||||
new_indices.push(j);
|
||||
new_values.push(value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
new_offsets.push(new_indices.len());
|
||||
}
|
||||
|
||||
// TODO: Avoid checks here
|
||||
let new_pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||
major_dim,
|
||||
minor_dim,
|
||||
new_offsets,
|
||||
new_indices,
|
||||
)
|
||||
.expect("Internal error: Sparsity pattern must always be valid.");
|
||||
|
||||
Self::from_pattern_and_values(new_pattern, new_values)
|
||||
}
|
||||
|
||||
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||
pub fn diagonal_as_matrix(&self) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
// TODO: This might be faster with a binary search for each diagonal entry
|
||||
self.filter(|i, j, _| i == j)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Scalar + One> CsMatrix<T> {
|
||||
#[inline]
|
||||
pub fn identity(n: usize) -> Self {
|
||||
let offsets: Vec<_> = (0..=n).collect();
|
||||
let indices: Vec<_> = (0..n).collect();
|
||||
let values = vec![T::one(); n];
|
||||
|
||||
// TODO: We should skip checks here
|
||||
let pattern =
|
||||
SparsityPattern::try_from_offsets_and_indices(n, n, offsets, indices).unwrap();
|
||||
Self::from_pattern_and_values(pattern, values)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_entry_from_slices<'a, T>(
|
||||
minor_dim: usize,
|
||||
minor_indices: &'a [usize],
|
||||
values: &'a [T],
|
||||
global_minor_index: usize,
|
||||
) -> Option<SparseEntry<'a, T>> {
|
||||
let local_index = minor_indices.binary_search(&global_minor_index);
|
||||
if let Ok(local_index) = local_index {
|
||||
Some(SparseEntry::NonZero(&values[local_index]))
|
||||
} else if global_minor_index < minor_dim {
|
||||
Some(SparseEntry::Zero)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_mut_entry_from_slices<'a, T>(
|
||||
minor_dim: usize,
|
||||
minor_indices: &'a [usize],
|
||||
values: &'a mut [T],
|
||||
global_minor_indices: usize,
|
||||
) -> Option<SparseEntryMut<'a, T>> {
|
||||
let local_index = minor_indices.binary_search(&global_minor_indices);
|
||||
if let Ok(local_index) = local_index {
|
||||
Some(SparseEntryMut::NonZero(&mut values[local_index]))
|
||||
} else if global_minor_indices < minor_dim {
|
||||
Some(SparseEntryMut::Zero)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CsLane<'a, T> {
|
||||
minor_dim: usize,
|
||||
minor_indices: &'a [usize],
|
||||
values: &'a [T],
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct CsLaneMut<'a, T> {
|
||||
minor_dim: usize,
|
||||
minor_indices: &'a [usize],
|
||||
values: &'a mut [T],
|
||||
}
|
||||
|
||||
pub struct CsLaneIter<'a, T> {
|
||||
// The index of the lane that will be returned on the next iteration
|
||||
current_lane_idx: usize,
|
||||
pattern: &'a SparsityPattern,
|
||||
remaining_values: &'a [T],
|
||||
}
|
||||
|
||||
impl<'a, T> CsLaneIter<'a, T> {
|
||||
pub fn new(pattern: &'a SparsityPattern, values: &'a [T]) -> Self {
|
||||
Self {
|
||||
current_lane_idx: 0,
|
||||
pattern,
|
||||
remaining_values: values,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsLaneIter<'a, T>
|
||||
where
|
||||
T: 'a,
|
||||
{
|
||||
type Item = CsLane<'a, T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let lane = self.pattern.get_lane(self.current_lane_idx);
|
||||
let minor_dim = self.pattern.minor_dim();
|
||||
|
||||
if let Some(minor_indices) = lane {
|
||||
let count = minor_indices.len();
|
||||
let values_in_lane = &self.remaining_values[..count];
|
||||
self.remaining_values = &self.remaining_values[count..];
|
||||
self.current_lane_idx += 1;
|
||||
|
||||
Some(CsLane {
|
||||
minor_dim,
|
||||
minor_indices,
|
||||
values: values_in_lane,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CsLaneIterMut<'a, T> {
|
||||
// The index of the lane that will be returned on the next iteration
|
||||
current_lane_idx: usize,
|
||||
pattern: &'a SparsityPattern,
|
||||
remaining_values: &'a mut [T],
|
||||
}
|
||||
|
||||
impl<'a, T> CsLaneIterMut<'a, T> {
|
||||
pub fn new(pattern: &'a SparsityPattern, values: &'a mut [T]) -> Self {
|
||||
Self {
|
||||
current_lane_idx: 0,
|
||||
pattern,
|
||||
remaining_values: values,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsLaneIterMut<'a, T>
|
||||
where
|
||||
T: 'a,
|
||||
{
|
||||
type Item = CsLaneMut<'a, T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let lane = self.pattern.get_lane(self.current_lane_idx);
|
||||
let minor_dim = self.pattern.minor_dim();
|
||||
|
||||
if let Some(minor_indices) = lane {
|
||||
let count = minor_indices.len();
|
||||
|
||||
let remaining = replace(&mut self.remaining_values, &mut []);
|
||||
let (values_in_lane, remaining) = remaining.split_at_mut(count);
|
||||
self.remaining_values = remaining;
|
||||
self.current_lane_idx += 1;
|
||||
|
||||
Some(CsLaneMut {
|
||||
minor_dim,
|
||||
minor_indices,
|
||||
values: values_in_lane,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Implement the methods common to both CsLane and CsLaneMut. See the documentation for the
|
||||
/// methods delegated here by CsrMatrix and CscMatrix members for more information.
|
||||
macro_rules! impl_cs_lane_common_methods {
|
||||
($name:ty) => {
|
||||
impl<'a, T> $name {
|
||||
#[inline]
|
||||
pub fn minor_dim(&self) -> usize {
|
||||
self.minor_dim
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn nnz(&self) -> usize {
|
||||
self.minor_indices.len()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn minor_indices(&self) -> &[usize] {
|
||||
self.minor_indices
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn values(&self) -> &[T] {
|
||||
self.values
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
|
||||
get_entry_from_slices(
|
||||
self.minor_dim,
|
||||
self.minor_indices,
|
||||
self.values,
|
||||
global_col_index,
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_cs_lane_common_methods!(CsLane<'a, T>);
|
||||
impl_cs_lane_common_methods!(CsLaneMut<'a, T>);
|
||||
|
||||
impl<'a, T> CsLaneMut<'a, T> {
|
||||
pub fn values_mut(&mut self) -> &mut [T] {
|
||||
self.values
|
||||
}
|
||||
|
||||
pub fn indices_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||
(self.minor_indices, self.values)
|
||||
}
|
||||
|
||||
pub fn get_entry_mut(&mut self, global_minor_index: usize) -> Option<SparseEntryMut<T>> {
|
||||
get_mut_entry_from_slices(
|
||||
self.minor_dim,
|
||||
self.minor_indices,
|
||||
self.values,
|
||||
global_minor_index,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper struct for working with uninitialized data in vectors.
|
||||
/// TODO: This doesn't belong here.
|
||||
struct UninitVec<T> {
|
||||
vec: Vec<T>,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl<T> UninitVec<T> {
|
||||
pub fn from_len(len: usize) -> Self {
|
||||
Self {
|
||||
vec: Vec::with_capacity(len),
|
||||
// We need to store len separately, because for zero-sized types,
|
||||
// Vec::with_capacity(len) does not give vec.capacity() == len
|
||||
len,
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the element associated with the given index to the provided value.
|
||||
///
|
||||
/// Must be called exactly once per index, otherwise results in undefined behavior.
|
||||
pub unsafe fn set(&mut self, index: usize, value: T) {
|
||||
self.vec.as_mut_ptr().add(index).write(value)
|
||||
}
|
||||
|
||||
/// Marks the vector data as initialized by returning a full vector.
|
||||
///
|
||||
/// It is undefined behavior to call this function unless *all* elements have been written to
|
||||
/// exactly once.
|
||||
pub unsafe fn assume_init(mut self) -> Vec<T> {
|
||||
self.vec.set_len(self.len);
|
||||
self.vec
|
||||
}
|
||||
}
|
||||
|
||||
/// Transposes the compressed format.
|
||||
///
|
||||
/// This means that major and minor roles are switched. This is used for converting between CSR
|
||||
/// and CSC formats.
|
||||
pub fn transpose_cs<T>(
|
||||
major_dim: usize,
|
||||
minor_dim: usize,
|
||||
source_major_offsets: &[usize],
|
||||
source_minor_indices: &[usize],
|
||||
values: &[T],
|
||||
) -> (Vec<usize>, Vec<usize>, Vec<T>)
|
||||
where
|
||||
T: Scalar,
|
||||
{
|
||||
assert_eq!(source_major_offsets.len(), major_dim + 1);
|
||||
assert_eq!(source_minor_indices.len(), values.len());
|
||||
let nnz = values.len();
|
||||
|
||||
// Count the number of occurences of each minor index
|
||||
let mut minor_counts = vec![0; minor_dim];
|
||||
for minor_idx in source_minor_indices {
|
||||
minor_counts[*minor_idx] += 1;
|
||||
}
|
||||
convert_counts_to_offsets(&mut minor_counts);
|
||||
let mut target_offsets = minor_counts;
|
||||
target_offsets.push(nnz);
|
||||
let mut target_indices = vec![usize::MAX; nnz];
|
||||
|
||||
// We have to use uninitialized storage, because we don't have any kind of "default" value
|
||||
// available for `T`. Unfortunately this necessitates some small amount of unsafe code
|
||||
let mut target_values = UninitVec::from_len(nnz);
|
||||
|
||||
// Keep track of how many entries we have placed in each target major lane
|
||||
let mut current_target_major_counts = vec![0; minor_dim];
|
||||
|
||||
for source_major_idx in 0..major_dim {
|
||||
let source_lane_begin = source_major_offsets[source_major_idx];
|
||||
let source_lane_end = source_major_offsets[source_major_idx + 1];
|
||||
let source_lane_indices = &source_minor_indices[source_lane_begin..source_lane_end];
|
||||
let source_lane_values = &values[source_lane_begin..source_lane_end];
|
||||
|
||||
for (&source_minor_idx, val) in source_lane_indices.iter().zip(source_lane_values) {
|
||||
// Compute the offset in the target data for this particular source entry
|
||||
let target_lane_count = &mut current_target_major_counts[source_minor_idx];
|
||||
let entry_offset = target_offsets[source_minor_idx] + *target_lane_count;
|
||||
target_indices[entry_offset] = source_major_idx;
|
||||
unsafe {
|
||||
target_values.set(entry_offset, val.inlined_clone());
|
||||
}
|
||||
*target_lane_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// At this point, we should have written to each element in target_values exactly once,
|
||||
// so initialization should be sound
|
||||
let target_values = unsafe { target_values.assume_init() };
|
||||
(target_offsets, target_indices, target_values)
|
||||
}
|
||||
|
||||
pub fn convert_counts_to_offsets(counts: &mut [usize]) {
|
||||
// Convert the counts to an offset
|
||||
let mut offset = 0;
|
||||
for i_offset in counts.iter_mut() {
|
||||
let count = *i_offset;
|
||||
*i_offset = offset;
|
||||
offset += count;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,704 @@
|
|||
//! An implementation of the CSC sparse matrix format.
|
||||
//!
|
||||
//! This is the module-level documentation. See [`CscMatrix`] for the main documentation of the
|
||||
//! CSC implementation.
|
||||
|
||||
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||
|
||||
use nalgebra::Scalar;
|
||||
use num_traits::One;
|
||||
use std::slice::{Iter, IterMut};
|
||||
|
||||
/// A CSC representation of a sparse matrix.
|
||||
///
|
||||
/// The Compressed Sparse Column (CSC) format is well-suited as a general-purpose storage format
|
||||
/// for many sparse matrix applications.
|
||||
///
|
||||
/// # Usage
|
||||
///
|
||||
/// ```rust
|
||||
/// use nalgebra_sparse::csc::CscMatrix;
|
||||
/// use nalgebra::{DMatrix, Matrix3x4};
|
||||
/// use matrixcompare::assert_matrix_eq;
|
||||
///
|
||||
/// // The sparsity patterns of CSC matrices are immutable. This means that you cannot dynamically
|
||||
/// // change the sparsity pattern of the matrix after it has been constructed. The easiest
|
||||
/// // way to construct a CSC matrix is to first incrementally construct a COO matrix,
|
||||
/// // and then convert it to CSC.
|
||||
/// # use nalgebra_sparse::coo::CooMatrix;
|
||||
/// # let coo = CooMatrix::<f64>::new(3, 3);
|
||||
/// let csc = CscMatrix::from(&coo);
|
||||
///
|
||||
/// // Alternatively, a CSC matrix can be constructed directly from raw CSC data.
|
||||
/// // Here, we construct a 3x4 matrix
|
||||
/// let col_offsets = vec![0, 1, 3, 4, 5];
|
||||
/// let row_indices = vec![0, 0, 2, 2, 0];
|
||||
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
///
|
||||
/// // The dense representation of the CSC data, for comparison
|
||||
/// let dense = Matrix3x4::new(1.0, 2.0, 0.0, 5.0,
|
||||
/// 0.0, 0.0, 0.0, 0.0,
|
||||
/// 0.0, 3.0, 4.0, 0.0);
|
||||
///
|
||||
/// // The constructor validates the raw CSC data and returns an error if it is invalid.
|
||||
/// let csc = CscMatrix::try_from_csc_data(3, 4, col_offsets, row_indices, values)
|
||||
/// .expect("CSC data must conform to format specifications");
|
||||
/// assert_matrix_eq!(csc, dense);
|
||||
///
|
||||
/// // A third approach is to construct a CSC matrix from a pattern and values. Sometimes this is
|
||||
/// // useful if the sparsity pattern is constructed separately from the values of the matrix.
|
||||
/// let (pattern, values) = csc.into_pattern_and_values();
|
||||
/// let csc = CscMatrix::try_from_pattern_and_values(pattern, values)
|
||||
/// .expect("The pattern and values must be compatible");
|
||||
///
|
||||
/// // Once we have constructed our matrix, we can use it for arithmetic operations together with
|
||||
/// // other CSC matrices and dense matrices/vectors.
|
||||
/// let x = csc;
|
||||
/// # #[allow(non_snake_case)]
|
||||
/// let xTx = x.transpose() * &x;
|
||||
/// let z = DMatrix::from_fn(4, 8, |i, j| (i as f64) * (j as f64));
|
||||
/// let w = 3.0 * xTx * z;
|
||||
///
|
||||
/// // Although the sparsity pattern of a CSC matrix cannot be changed, its values can.
|
||||
/// // Here are two different ways to scale all values by a constant:
|
||||
/// let mut x = x;
|
||||
/// x *= 5.0;
|
||||
/// x.values_mut().iter_mut().for_each(|x_i| *x_i *= 5.0);
|
||||
/// ```
|
||||
///
|
||||
/// # Format
|
||||
///
|
||||
/// An `m x n` sparse matrix with `nnz` non-zeros in CSC format is represented by the
|
||||
/// following three arrays:
|
||||
///
|
||||
/// - `col_offsets`, an array of integers with length `n + 1`.
|
||||
/// - `row_indices`, an array of integers with length `nnz`.
|
||||
/// - `values`, an array of values with length `nnz`.
|
||||
///
|
||||
/// The relationship between the arrays is described below.
|
||||
///
|
||||
/// - Each consecutive pair of entries `col_offsets[j] .. col_offsets[j + 1]` corresponds to an
|
||||
/// offset range in `row_indices` that holds the row indices in column `j`.
|
||||
/// - For an entry represented by the index `idx`, `row_indices[idx]` stores its column index and
|
||||
/// `values[idx]` stores its value.
|
||||
///
|
||||
/// The following invariants must be upheld and are enforced by the data structure:
|
||||
///
|
||||
/// - `col_offsets[0] == 0`
|
||||
/// - `col_offsets[m] == nnz`
|
||||
/// - `col_offsets` is monotonically increasing.
|
||||
/// - `0 <= row_indices[idx] < m` for all `idx < nnz`.
|
||||
/// - The row indices associated with each column are monotonically increasing (see below).
|
||||
///
|
||||
/// The CSC format is a standard sparse matrix format (see [Wikipedia article]). The format
|
||||
/// represents the matrix in a column-by-column fashion. The entries associated with column `j` are
|
||||
/// determined as follows:
|
||||
///
|
||||
/// ```rust
|
||||
/// # let col_offsets: Vec<usize> = vec![0, 0];
|
||||
/// # let row_indices: Vec<usize> = vec![];
|
||||
/// # let values: Vec<i32> = vec![];
|
||||
/// # let j = 0;
|
||||
/// let range = col_offsets[j] .. col_offsets[j + 1];
|
||||
/// let col_j_rows = &row_indices[range.clone()];
|
||||
/// let col_j_vals = &values[range];
|
||||
///
|
||||
/// // For each pair (i, v) in (col_j_rows, col_j_vals), we obtain a corresponding entry
|
||||
/// // (i, j, v) in the matrix.
|
||||
/// assert_eq!(col_j_rows.len(), col_j_vals.len());
|
||||
/// ```
|
||||
///
|
||||
/// In the above example, for each column `j`, the row indices `col_j_cols` must appear in
|
||||
/// monotonically increasing order. In other words, they must be *sorted*. This criterion is not
|
||||
/// standard among all sparse matrix libraries, but we enforce this property as it is a crucial
|
||||
/// assumption for both correctness and performance for many algorithms.
|
||||
///
|
||||
/// Note that the CSR and CSC formats are essentially identical, except that CSC stores the matrix
|
||||
/// column-by-column instead of row-by-row like CSR.
|
||||
///
|
||||
/// [Wikipedia article]: https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_(CSC_or_CCS)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CscMatrix<T> {
|
||||
// Cols are major, rows are minor in the sparsity pattern
|
||||
pub(crate) cs: CsMatrix<T>,
|
||||
}
|
||||
|
||||
impl<T> CscMatrix<T> {
|
||||
/// Constructs a CSC representation of the (square) `n x n` identity matrix.
|
||||
#[inline]
|
||||
pub fn identity(n: usize) -> Self
|
||||
where
|
||||
T: Scalar + One,
|
||||
{
|
||||
Self {
|
||||
cs: CsMatrix::identity(n),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a zero CSC matrix with no explicitly stored entries.
|
||||
pub fn zeros(nrows: usize, ncols: usize) -> Self {
|
||||
Self {
|
||||
cs: CsMatrix::new(ncols, nrows),
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to construct a CSC matrix from raw CSC data.
|
||||
///
|
||||
/// It is assumed that each column contains unique and sorted row indices that are in
|
||||
/// bounds with respect to the number of rows in the matrix. If this is not the case,
|
||||
/// an error is returned to indicate the failure.
|
||||
///
|
||||
/// An error is returned if the data given does not conform to the CSC storage format.
|
||||
/// See the documentation for [CscMatrix](struct.CscMatrix.html) for more information.
|
||||
pub fn try_from_csc_data(
|
||||
num_rows: usize,
|
||||
num_cols: usize,
|
||||
col_offsets: Vec<usize>,
|
||||
row_indices: Vec<usize>,
|
||||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||
num_cols,
|
||||
num_rows,
|
||||
col_offsets,
|
||||
row_indices,
|
||||
)
|
||||
.map_err(pattern_format_error_to_csc_error)?;
|
||||
Self::try_from_pattern_and_values(pattern, values)
|
||||
}
|
||||
|
||||
/// Try to construct a CSC matrix from a sparsity pattern and associated non-zero values.
|
||||
///
|
||||
/// Returns an error if the number of values does not match the number of minor indices
|
||||
/// in the pattern.
|
||||
pub fn try_from_pattern_and_values(
|
||||
pattern: SparsityPattern,
|
||||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
if pattern.nnz() == values.len() {
|
||||
Ok(Self {
|
||||
cs: CsMatrix::from_pattern_and_values(pattern, values),
|
||||
})
|
||||
} else {
|
||||
Err(SparseFormatError::from_kind_and_msg(
|
||||
SparseFormatErrorKind::InvalidStructure,
|
||||
"Number of values and row indices must be the same",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// The number of rows in the matrix.
|
||||
#[inline]
|
||||
pub fn nrows(&self) -> usize {
|
||||
self.cs.pattern().minor_dim()
|
||||
}
|
||||
|
||||
/// The number of columns in the matrix.
|
||||
#[inline]
|
||||
pub fn ncols(&self) -> usize {
|
||||
self.cs.pattern().major_dim()
|
||||
}
|
||||
|
||||
/// The number of non-zeros in the matrix.
|
||||
///
|
||||
/// Note that this corresponds to the number of explicitly stored entries, *not* the actual
|
||||
/// number of algebraically zero entries in the matrix. Explicitly stored entries can still
|
||||
/// be zero. Corresponds to the number of entries in the sparsity pattern.
|
||||
#[inline]
|
||||
pub fn nnz(&self) -> usize {
|
||||
self.pattern().nnz()
|
||||
}
|
||||
|
||||
/// The column offsets defining part of the CSC format.
|
||||
#[inline]
|
||||
pub fn col_offsets(&self) -> &[usize] {
|
||||
self.pattern().major_offsets()
|
||||
}
|
||||
|
||||
/// The row indices defining part of the CSC format.
|
||||
#[inline]
|
||||
pub fn row_indices(&self) -> &[usize] {
|
||||
self.pattern().minor_indices()
|
||||
}
|
||||
|
||||
/// The non-zero values defining part of the CSC format.
|
||||
#[inline]
|
||||
pub fn values(&self) -> &[T] {
|
||||
self.cs.values()
|
||||
}
|
||||
|
||||
/// Mutable access to the non-zero values.
|
||||
#[inline]
|
||||
pub fn values_mut(&mut self) -> &mut [T] {
|
||||
self.cs.values_mut()
|
||||
}
|
||||
|
||||
/// An iterator over non-zero triplets (i, j, v).
|
||||
///
|
||||
/// The iteration happens in column-major fashion, meaning that j increases monotonically,
|
||||
/// and i increases monotonically within each row.
|
||||
///
|
||||
/// Examples
|
||||
/// --------
|
||||
/// ```
|
||||
/// # use nalgebra_sparse::csc::CscMatrix;
|
||||
/// let col_offsets = vec![0, 2, 3, 4];
|
||||
/// let row_indices = vec![0, 2, 1, 0];
|
||||
/// let values = vec![1, 3, 2, 4];
|
||||
/// let mut csc = CscMatrix::try_from_csc_data(4, 3, col_offsets, row_indices, values)
|
||||
/// .unwrap();
|
||||
///
|
||||
/// let triplets: Vec<_> = csc.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||
/// assert_eq!(triplets, vec![(0, 0, 1), (2, 0, 3), (1, 1, 2), (0, 2, 4)]);
|
||||
/// ```
|
||||
pub fn triplet_iter(&self) -> CscTripletIter<T> {
|
||||
CscTripletIter {
|
||||
pattern_iter: self.pattern().entries(),
|
||||
values_iter: self.values().iter(),
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable iterator over non-zero triplets (i, j, v).
|
||||
///
|
||||
/// Iteration happens in the same order as for [triplet_iter](#method.triplet_iter).
|
||||
///
|
||||
/// Examples
|
||||
/// --------
|
||||
/// ```
|
||||
/// # use nalgebra_sparse::csc::CscMatrix;
|
||||
/// let col_offsets = vec![0, 2, 3, 4];
|
||||
/// let row_indices = vec![0, 2, 1, 0];
|
||||
/// let values = vec![1, 3, 2, 4];
|
||||
/// // Using the same data as in the `triplet_iter` example
|
||||
/// let mut csc = CscMatrix::try_from_csc_data(4, 3, col_offsets, row_indices, values)
|
||||
/// .unwrap();
|
||||
///
|
||||
/// // Zero out lower-triangular terms
|
||||
/// csc.triplet_iter_mut()
|
||||
/// .filter(|(i, j, _)| j < i)
|
||||
/// .for_each(|(_, _, v)| *v = 0);
|
||||
///
|
||||
/// let triplets: Vec<_> = csc.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||
/// assert_eq!(triplets, vec![(0, 0, 1), (2, 0, 0), (1, 1, 2), (0, 2, 4)]);
|
||||
/// ```
|
||||
pub fn triplet_iter_mut(&mut self) -> CscTripletIterMut<T> {
|
||||
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||
CscTripletIterMut {
|
||||
pattern_iter: pattern.entries(),
|
||||
values_mut_iter: values.iter_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the column at the given column index.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
/// Panics if column index is out of bounds.
|
||||
#[inline]
|
||||
pub fn col(&self, index: usize) -> CscCol<T> {
|
||||
self.get_col(index).expect("Row index must be in bounds")
|
||||
}
|
||||
|
||||
/// Mutable column access for the given column index.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
/// Panics if column index is out of bounds.
|
||||
#[inline]
|
||||
pub fn col_mut(&mut self, index: usize) -> CscColMut<T> {
|
||||
self.get_col_mut(index)
|
||||
.expect("Row index must be in bounds")
|
||||
}
|
||||
|
||||
/// Return the column at the given column index, or `None` if out of bounds.
|
||||
#[inline]
|
||||
pub fn get_col(&self, index: usize) -> Option<CscCol<T>> {
|
||||
self.cs.get_lane(index).map(|lane| CscCol { lane })
|
||||
}
|
||||
|
||||
/// Mutable column access for the given column index, or `None` if out of bounds.
|
||||
#[inline]
|
||||
pub fn get_col_mut(&mut self, index: usize) -> Option<CscColMut<T>> {
|
||||
self.cs.get_lane_mut(index).map(|lane| CscColMut { lane })
|
||||
}
|
||||
|
||||
/// An iterator over columns in the matrix.
|
||||
pub fn col_iter(&self) -> CscColIter<T> {
|
||||
CscColIter {
|
||||
lane_iter: CsLaneIter::new(self.pattern(), self.values()),
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable iterator over columns in the matrix.
|
||||
pub fn col_iter_mut(&mut self) -> CscColIterMut<T> {
|
||||
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||
CscColIterMut {
|
||||
lane_iter: CsLaneIterMut::new(pattern, values),
|
||||
}
|
||||
}
|
||||
|
||||
/// Disassembles the CSC matrix into its underlying offset, index and value arrays.
|
||||
///
|
||||
/// If the matrix contains the sole reference to the sparsity pattern,
|
||||
/// then the data is returned as-is. Otherwise, the sparsity pattern is cloned.
|
||||
///
|
||||
/// Examples
|
||||
/// --------
|
||||
///
|
||||
/// ```
|
||||
/// # use nalgebra_sparse::csc::CscMatrix;
|
||||
/// let col_offsets = vec![0, 2, 3, 4];
|
||||
/// let row_indices = vec![0, 2, 1, 0];
|
||||
/// let values = vec![1, 3, 2, 4];
|
||||
/// let mut csc = CscMatrix::try_from_csc_data(
|
||||
/// 4,
|
||||
/// 3,
|
||||
/// col_offsets.clone(),
|
||||
/// row_indices.clone(),
|
||||
/// values.clone())
|
||||
/// .unwrap();
|
||||
/// let (col_offsets2, row_indices2, values2) = csc.disassemble();
|
||||
/// assert_eq!(col_offsets2, col_offsets);
|
||||
/// assert_eq!(row_indices2, row_indices);
|
||||
/// assert_eq!(values2, values);
|
||||
/// ```
|
||||
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||
self.cs.disassemble()
|
||||
}
|
||||
|
||||
/// Returns the sparsity pattern and values associated with this matrix.
|
||||
pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
||||
self.cs.into_pattern_and_values()
|
||||
}
|
||||
|
||||
/// Returns a reference to the sparsity pattern and a mutable reference to the values.
|
||||
#[inline]
|
||||
pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
|
||||
self.cs.pattern_and_values_mut()
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying sparsity pattern.
|
||||
pub fn pattern(&self) -> &SparsityPattern {
|
||||
self.cs.pattern()
|
||||
}
|
||||
|
||||
/// Reinterprets the CSC matrix as its transpose represented by a CSR matrix.
|
||||
///
|
||||
/// This operation does not touch the CSC data, and is effectively a no-op.
|
||||
pub fn transpose_as_csr(self) -> CsrMatrix<T> {
|
||||
let (pattern, values) = self.cs.take_pattern_and_values();
|
||||
CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||
}
|
||||
|
||||
/// Returns an entry for the given row/col indices, or `None` if the indices are out of bounds.
|
||||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
/// stored row entries for the given column.
|
||||
pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> {
|
||||
self.cs.get_entry(col_index, row_index)
|
||||
}
|
||||
|
||||
/// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
|
||||
/// of bounds.
|
||||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
/// stored row entries for the given column.
|
||||
pub fn get_entry_mut(
|
||||
&mut self,
|
||||
row_index: usize,
|
||||
col_index: usize,
|
||||
) -> Option<SparseEntryMut<T>> {
|
||||
self.cs.get_entry_mut(col_index, row_index)
|
||||
}
|
||||
|
||||
/// Returns an entry for the given row/col indices.
|
||||
///
|
||||
/// Same as `get_entry`, except that it directly panics upon encountering row/col indices
|
||||
/// out of bounds.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
/// Panics if `row_index` or `col_index` is out of bounds.
|
||||
pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<T> {
|
||||
self.get_entry(row_index, col_index)
|
||||
.expect("Out of bounds matrix indices encountered")
|
||||
}
|
||||
|
||||
/// Returns a mutable entry for the given row/col indices.
|
||||
///
|
||||
/// Same as `get_entry_mut`, except that it directly panics upon encountering row/col indices
|
||||
/// out of bounds.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
/// Panics if `row_index` or `col_index` is out of bounds.
|
||||
pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<T> {
|
||||
self.get_entry_mut(row_index, col_index)
|
||||
.expect("Out of bounds matrix indices encountered")
|
||||
}
|
||||
|
||||
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSC data.
|
||||
pub fn csc_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||
self.cs.cs_data()
|
||||
}
|
||||
|
||||
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSC data,
|
||||
/// where the `values` array is mutable.
|
||||
pub fn csc_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||
self.cs.cs_data_mut()
|
||||
}
|
||||
|
||||
/// Creates a sparse matrix that contains only the explicit entries decided by the
|
||||
/// given predicate.
|
||||
pub fn filter<P>(&self, predicate: P) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
P: Fn(usize, usize, &T) -> bool,
|
||||
{
|
||||
// Note: Predicate uses (row, col, value), so we have to switch around since
|
||||
// cs uses (major, minor, value)
|
||||
Self {
|
||||
cs: self
|
||||
.cs
|
||||
.filter(|col_idx, row_idx, v| predicate(row_idx, col_idx, v)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new matrix representing the upper triangular part of this matrix.
|
||||
///
|
||||
/// The result includes the diagonal of the matrix.
|
||||
pub fn upper_triangle(&self) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
self.filter(|i, j, _| i <= j)
|
||||
}
|
||||
|
||||
/// Returns a new matrix representing the lower triangular part of this matrix.
|
||||
///
|
||||
/// The result includes the diagonal of the matrix.
|
||||
pub fn lower_triangle(&self) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
self.filter(|i, j, _| i >= j)
|
||||
}
|
||||
|
||||
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||
pub fn diagonal_as_csc(&self) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
Self {
|
||||
cs: self.cs.diagonal_as_matrix(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the transpose of the matrix.
|
||||
pub fn transpose(&self) -> CscMatrix<T>
|
||||
where
|
||||
T: Scalar,
|
||||
{
|
||||
CsrMatrix::from(self).transpose_as_csc()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert pattern format errors into more meaningful CSC-specific errors.
|
||||
///
|
||||
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
||||
/// not lanes, major and minor dimensions.
|
||||
fn pattern_format_error_to_csc_error(err: SparsityPatternFormatError) -> SparseFormatError {
|
||||
use SparseFormatError as E;
|
||||
use SparseFormatErrorKind as K;
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use SparsityPatternFormatError::*;
|
||||
|
||||
match err {
|
||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Length of col offset array is not equal to ncols + 1.",
|
||||
),
|
||||
InvalidOffsetFirstLast => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"First or last col offset is inconsistent with format specification.",
|
||||
),
|
||||
NonmonotonicOffsets => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Col offsets are not monotonically increasing.",
|
||||
),
|
||||
NonmonotonicMinorIndices => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Row indices are not monotonically increasing (sorted) within each column.",
|
||||
),
|
||||
MinorIndexOutOfBounds => {
|
||||
E::from_kind_and_msg(K::IndexOutOfBounds, "Row indices are out of bounds.")
|
||||
}
|
||||
PatternDuplicateEntry => {
|
||||
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterator type for iterating over triplets in a CSC matrix.
|
||||
#[derive(Debug)]
|
||||
pub struct CscTripletIter<'a, T> {
|
||||
pattern_iter: SparsityPatternIter<'a>,
|
||||
values_iter: Iter<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T: Clone> CscTripletIter<'a, T> {
|
||||
/// Adapts the triplet iterator to return owned values.
|
||||
///
|
||||
/// The triplet iterator returns references to the values. This method adapts the iterator
|
||||
/// so that the values are cloned.
|
||||
#[inline]
|
||||
pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
|
||||
self.map(|(i, j, v)| (i, j, v.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CscTripletIter<'a, T> {
|
||||
type Item = (usize, usize, &'a T);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let next_entry = self.pattern_iter.next();
|
||||
let next_value = self.values_iter.next();
|
||||
|
||||
match (next_entry, next_value) {
|
||||
(Some((i, j)), Some(v)) => Some((j, i, v)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterator type for mutably iterating over triplets in a CSC matrix.
|
||||
#[derive(Debug)]
|
||||
pub struct CscTripletIterMut<'a, T> {
|
||||
pattern_iter: SparsityPatternIter<'a>,
|
||||
values_mut_iter: IterMut<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CscTripletIterMut<'a, T> {
|
||||
type Item = (usize, usize, &'a mut T);
|
||||
|
||||
#[inline]
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let next_entry = self.pattern_iter.next();
|
||||
let next_value = self.values_mut_iter.next();
|
||||
|
||||
match (next_entry, next_value) {
|
||||
(Some((i, j)), Some(v)) => Some((j, i, v)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An immutable representation of a column in a CSC matrix.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CscCol<'a, T> {
|
||||
lane: CsLane<'a, T>,
|
||||
}
|
||||
|
||||
/// A mutable representation of a column in a CSC matrix.
|
||||
///
|
||||
/// Note that only explicitly stored entries can be mutated. The sparsity pattern belonging
|
||||
/// to the column cannot be modified.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct CscColMut<'a, T> {
|
||||
lane: CsLaneMut<'a, T>,
|
||||
}
|
||||
|
||||
/// Implement the methods common to both CscCol and CscColMut
|
||||
macro_rules! impl_csc_col_common_methods {
|
||||
($name:ty) => {
|
||||
impl<'a, T> $name {
|
||||
/// The number of global rows in the column.
|
||||
#[inline]
|
||||
pub fn nrows(&self) -> usize {
|
||||
self.lane.minor_dim()
|
||||
}
|
||||
|
||||
/// The number of non-zeros in this column.
|
||||
#[inline]
|
||||
pub fn nnz(&self) -> usize {
|
||||
self.lane.nnz()
|
||||
}
|
||||
|
||||
/// The row indices corresponding to explicitly stored entries in this column.
|
||||
#[inline]
|
||||
pub fn row_indices(&self) -> &[usize] {
|
||||
self.lane.minor_indices()
|
||||
}
|
||||
|
||||
/// The values corresponding to explicitly stored entries in this column.
|
||||
#[inline]
|
||||
pub fn values(&self) -> &[T] {
|
||||
self.lane.values()
|
||||
}
|
||||
|
||||
/// Returns an entry for the given global row index.
|
||||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
/// stored row entries.
|
||||
pub fn get_entry(&self, global_row_index: usize) -> Option<SparseEntry<T>> {
|
||||
self.lane.get_entry(global_row_index)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_csc_col_common_methods!(CscCol<'a, T>);
|
||||
impl_csc_col_common_methods!(CscColMut<'a, T>);
|
||||
|
||||
impl<'a, T> CscColMut<'a, T> {
|
||||
/// Mutable access to the values corresponding to explicitly stored entries in this column.
|
||||
pub fn values_mut(&mut self) -> &mut [T] {
|
||||
self.lane.values_mut()
|
||||
}
|
||||
|
||||
/// Provides simultaneous access to row indices and mutable values corresponding to the
|
||||
/// explicitly stored entries in this column.
|
||||
///
|
||||
/// This method primarily facilitates low-level access for methods that process data stored
|
||||
/// in CSC format directly.
|
||||
pub fn rows_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||
self.lane.indices_and_values_mut()
|
||||
}
|
||||
|
||||
/// Returns a mutable entry for the given global row index.
|
||||
pub fn get_entry_mut(&mut self, global_row_index: usize) -> Option<SparseEntryMut<T>> {
|
||||
self.lane.get_entry_mut(global_row_index)
|
||||
}
|
||||
}
|
||||
|
||||
/// Column iterator for [CscMatrix](struct.CscMatrix.html).
|
||||
pub struct CscColIter<'a, T> {
|
||||
lane_iter: CsLaneIter<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CscColIter<'a, T> {
|
||||
type Item = CscCol<'a, T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.lane_iter.next().map(|lane| CscCol { lane })
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutable column iterator for [CscMatrix](struct.CscMatrix.html).
|
||||
pub struct CscColIterMut<'a, T> {
|
||||
lane_iter: CsLaneIterMut<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CscColIterMut<'a, T>
|
||||
where
|
||||
T: 'a,
|
||||
{
|
||||
type Item = CscColMut<'a, T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.lane_iter.next().map(|lane| CscColMut { lane })
|
||||
}
|
||||
}
|
|
@ -0,0 +1,708 @@
|
|||
//! An implementation of the CSR sparse matrix format.
|
||||
//!
|
||||
//! This is the module-level documentation. See [`CsrMatrix`] for the main documentation of the
|
||||
//! CSC implementation.
|
||||
use crate::cs::{CsLane, CsLaneIter, CsLaneIterMut, CsLaneMut, CsMatrix};
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::pattern::{SparsityPattern, SparsityPatternFormatError, SparsityPatternIter};
|
||||
use crate::{SparseEntry, SparseEntryMut, SparseFormatError, SparseFormatErrorKind};
|
||||
|
||||
use nalgebra::Scalar;
|
||||
use num_traits::One;
|
||||
|
||||
use std::slice::{Iter, IterMut};
|
||||
|
||||
/// A CSR representation of a sparse matrix.
|
||||
///
|
||||
/// The Compressed Sparse Row (CSR) format is well-suited as a general-purpose storage format
|
||||
/// for many sparse matrix applications.
|
||||
///
|
||||
/// # Usage
|
||||
///
|
||||
/// ```rust
|
||||
/// use nalgebra_sparse::csr::CsrMatrix;
|
||||
/// use nalgebra::{DMatrix, Matrix3x4};
|
||||
/// use matrixcompare::assert_matrix_eq;
|
||||
///
|
||||
/// // The sparsity patterns of CSR matrices are immutable. This means that you cannot dynamically
|
||||
/// // change the sparsity pattern of the matrix after it has been constructed. The easiest
|
||||
/// // way to construct a CSR matrix is to first incrementally construct a COO matrix,
|
||||
/// // and then convert it to CSR.
|
||||
/// # use nalgebra_sparse::coo::CooMatrix;
|
||||
/// # let coo = CooMatrix::<f64>::new(3, 3);
|
||||
/// let csr = CsrMatrix::from(&coo);
|
||||
///
|
||||
/// // Alternatively, a CSR matrix can be constructed directly from raw CSR data.
|
||||
/// // Here, we construct a 3x4 matrix
|
||||
/// let row_offsets = vec![0, 3, 3, 5];
|
||||
/// let col_indices = vec![0, 1, 3, 1, 2];
|
||||
/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
///
|
||||
/// // The dense representation of the CSR data, for comparison
|
||||
/// let dense = Matrix3x4::new(1.0, 2.0, 0.0, 3.0,
|
||||
/// 0.0, 0.0, 0.0, 0.0,
|
||||
/// 0.0, 4.0, 5.0, 0.0);
|
||||
///
|
||||
/// // The constructor validates the raw CSR data and returns an error if it is invalid.
|
||||
/// let csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
|
||||
/// .expect("CSR data must conform to format specifications");
|
||||
/// assert_matrix_eq!(csr, dense);
|
||||
///
|
||||
/// // A third approach is to construct a CSR matrix from a pattern and values. Sometimes this is
|
||||
/// // useful if the sparsity pattern is constructed separately from the values of the matrix.
|
||||
/// let (pattern, values) = csr.into_pattern_and_values();
|
||||
/// let csr = CsrMatrix::try_from_pattern_and_values(pattern, values)
|
||||
/// .expect("The pattern and values must be compatible");
|
||||
///
|
||||
/// // Once we have constructed our matrix, we can use it for arithmetic operations together with
|
||||
/// // other CSR matrices and dense matrices/vectors.
|
||||
/// let x = csr;
|
||||
/// # #[allow(non_snake_case)]
|
||||
/// let xTx = x.transpose() * &x;
|
||||
/// let z = DMatrix::from_fn(4, 8, |i, j| (i as f64) * (j as f64));
|
||||
/// let w = 3.0 * xTx * z;
|
||||
///
|
||||
/// // Although the sparsity pattern of a CSR matrix cannot be changed, its values can.
|
||||
/// // Here are two different ways to scale all values by a constant:
|
||||
/// let mut x = x;
|
||||
/// x *= 5.0;
|
||||
/// x.values_mut().iter_mut().for_each(|x_i| *x_i *= 5.0);
|
||||
/// ```
|
||||
///
|
||||
/// # Format
|
||||
///
|
||||
/// An `m x n` sparse matrix with `nnz` non-zeros in CSR format is represented by the
|
||||
/// following three arrays:
|
||||
///
|
||||
/// - `row_offsets`, an array of integers with length `m + 1`.
|
||||
/// - `col_indices`, an array of integers with length `nnz`.
|
||||
/// - `values`, an array of values with length `nnz`.
|
||||
///
|
||||
/// The relationship between the arrays is described below.
|
||||
///
|
||||
/// - Each consecutive pair of entries `row_offsets[i] .. row_offsets[i + 1]` corresponds to an
|
||||
/// offset range in `col_indices` that holds the column indices in row `i`.
|
||||
/// - For an entry represented by the index `idx`, `col_indices[idx]` stores its column index and
|
||||
/// `values[idx]` stores its value.
|
||||
///
|
||||
/// The following invariants must be upheld and are enforced by the data structure:
|
||||
///
|
||||
/// - `row_offsets[0] == 0`
|
||||
/// - `row_offsets[m] == nnz`
|
||||
/// - `row_offsets` is monotonically increasing.
|
||||
/// - `0 <= col_indices[idx] < n` for all `idx < nnz`.
|
||||
/// - The column indices associated with each row are monotonically increasing (see below).
|
||||
///
|
||||
/// The CSR format is a standard sparse matrix format (see [Wikipedia article]). The format
|
||||
/// represents the matrix in a row-by-row fashion. The entries associated with row `i` are
|
||||
/// determined as follows:
|
||||
///
|
||||
/// ```rust
|
||||
/// # let row_offsets: Vec<usize> = vec![0, 0];
|
||||
/// # let col_indices: Vec<usize> = vec![];
|
||||
/// # let values: Vec<i32> = vec![];
|
||||
/// # let i = 0;
|
||||
/// let range = row_offsets[i] .. row_offsets[i + 1];
|
||||
/// let row_i_cols = &col_indices[range.clone()];
|
||||
/// let row_i_vals = &values[range];
|
||||
///
|
||||
/// // For each pair (j, v) in (row_i_cols, row_i_vals), we obtain a corresponding entry
|
||||
/// // (i, j, v) in the matrix.
|
||||
/// assert_eq!(row_i_cols.len(), row_i_vals.len());
|
||||
/// ```
|
||||
///
|
||||
/// In the above example, for each row `i`, the column indices `row_i_cols` must appear in
|
||||
/// monotonically increasing order. In other words, they must be *sorted*. This criterion is not
|
||||
/// standard among all sparse matrix libraries, but we enforce this property as it is a crucial
|
||||
/// assumption for both correctness and performance for many algorithms.
|
||||
///
|
||||
/// Note that the CSR and CSC formats are essentially identical, except that CSC stores the matrix
|
||||
/// column-by-column instead of row-by-row like CSR.
|
||||
///
|
||||
/// [Wikipedia article]: https://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_row_(CSR,_CRS_or_Yale_format)
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CsrMatrix<T> {
|
||||
// Rows are major, cols are minor in the sparsity pattern
|
||||
pub(crate) cs: CsMatrix<T>,
|
||||
}
|
||||
|
||||
impl<T> CsrMatrix<T> {
|
||||
/// Constructs a CSR representation of the (square) `n x n` identity matrix.
|
||||
#[inline]
|
||||
pub fn identity(n: usize) -> Self
|
||||
where
|
||||
T: Scalar + One,
|
||||
{
|
||||
Self {
|
||||
cs: CsMatrix::identity(n),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a zero CSR matrix with no explicitly stored entries.
|
||||
pub fn zeros(nrows: usize, ncols: usize) -> Self {
|
||||
Self {
|
||||
cs: CsMatrix::new(nrows, ncols),
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to construct a CSR matrix from raw CSR data.
|
||||
///
|
||||
/// It is assumed that each row contains unique and sorted column indices that are in
|
||||
/// bounds with respect to the number of columns in the matrix. If this is not the case,
|
||||
/// an error is returned to indicate the failure.
|
||||
///
|
||||
/// An error is returned if the data given does not conform to the CSR storage format.
|
||||
/// See the documentation for [CsrMatrix](struct.CsrMatrix.html) for more information.
|
||||
pub fn try_from_csr_data(
|
||||
num_rows: usize,
|
||||
num_cols: usize,
|
||||
row_offsets: Vec<usize>,
|
||||
col_indices: Vec<usize>,
|
||||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||
num_rows,
|
||||
num_cols,
|
||||
row_offsets,
|
||||
col_indices,
|
||||
)
|
||||
.map_err(pattern_format_error_to_csr_error)?;
|
||||
Self::try_from_pattern_and_values(pattern, values)
|
||||
}
|
||||
|
||||
/// Try to construct a CSR matrix from a sparsity pattern and associated non-zero values.
|
||||
///
|
||||
/// Returns an error if the number of values does not match the number of minor indices
|
||||
/// in the pattern.
|
||||
pub fn try_from_pattern_and_values(
|
||||
pattern: SparsityPattern,
|
||||
values: Vec<T>,
|
||||
) -> Result<Self, SparseFormatError> {
|
||||
if pattern.nnz() == values.len() {
|
||||
Ok(Self {
|
||||
cs: CsMatrix::from_pattern_and_values(pattern, values),
|
||||
})
|
||||
} else {
|
||||
Err(SparseFormatError::from_kind_and_msg(
|
||||
SparseFormatErrorKind::InvalidStructure,
|
||||
"Number of values and column indices must be the same",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// The number of rows in the matrix.
|
||||
#[inline]
|
||||
pub fn nrows(&self) -> usize {
|
||||
self.cs.pattern().major_dim()
|
||||
}
|
||||
|
||||
/// The number of columns in the matrix.
|
||||
#[inline]
|
||||
pub fn ncols(&self) -> usize {
|
||||
self.cs.pattern().minor_dim()
|
||||
}
|
||||
|
||||
/// The number of non-zeros in the matrix.
|
||||
///
|
||||
/// Note that this corresponds to the number of explicitly stored entries, *not* the actual
|
||||
/// number of algebraically zero entries in the matrix. Explicitly stored entries can still
|
||||
/// be zero. Corresponds to the number of entries in the sparsity pattern.
|
||||
#[inline]
|
||||
pub fn nnz(&self) -> usize {
|
||||
self.cs.pattern().nnz()
|
||||
}
|
||||
|
||||
/// The row offsets defining part of the CSR format.
|
||||
#[inline]
|
||||
pub fn row_offsets(&self) -> &[usize] {
|
||||
let (offsets, _, _) = self.cs.cs_data();
|
||||
offsets
|
||||
}
|
||||
|
||||
/// The column indices defining part of the CSR format.
|
||||
#[inline]
|
||||
pub fn col_indices(&self) -> &[usize] {
|
||||
let (_, indices, _) = self.cs.cs_data();
|
||||
indices
|
||||
}
|
||||
|
||||
/// The non-zero values defining part of the CSR format.
|
||||
#[inline]
|
||||
pub fn values(&self) -> &[T] {
|
||||
self.cs.values()
|
||||
}
|
||||
|
||||
/// Mutable access to the non-zero values.
|
||||
#[inline]
|
||||
pub fn values_mut(&mut self) -> &mut [T] {
|
||||
self.cs.values_mut()
|
||||
}
|
||||
|
||||
/// An iterator over non-zero triplets (i, j, v).
|
||||
///
|
||||
/// The iteration happens in row-major fashion, meaning that i increases monotonically,
|
||||
/// and j increases monotonically within each row.
|
||||
///
|
||||
/// Examples
|
||||
/// --------
|
||||
/// ```
|
||||
/// # use nalgebra_sparse::csr::CsrMatrix;
|
||||
/// let row_offsets = vec![0, 2, 3, 4];
|
||||
/// let col_indices = vec![0, 2, 1, 0];
|
||||
/// let values = vec![1, 2, 3, 4];
|
||||
/// let mut csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
|
||||
/// .unwrap();
|
||||
///
|
||||
/// let triplets: Vec<_> = csr.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||
/// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 4)]);
|
||||
/// ```
|
||||
pub fn triplet_iter(&self) -> CsrTripletIter<T> {
|
||||
CsrTripletIter {
|
||||
pattern_iter: self.pattern().entries(),
|
||||
values_iter: self.values().iter(),
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable iterator over non-zero triplets (i, j, v).
|
||||
///
|
||||
/// Iteration happens in the same order as for [triplet_iter](#method.triplet_iter).
|
||||
///
|
||||
/// Examples
|
||||
/// --------
|
||||
/// ```
|
||||
/// # use nalgebra_sparse::csr::CsrMatrix;
|
||||
/// # let row_offsets = vec![0, 2, 3, 4];
|
||||
/// # let col_indices = vec![0, 2, 1, 0];
|
||||
/// # let values = vec![1, 2, 3, 4];
|
||||
/// // Using the same data as in the `triplet_iter` example
|
||||
/// let mut csr = CsrMatrix::try_from_csr_data(3, 4, row_offsets, col_indices, values)
|
||||
/// .unwrap();
|
||||
///
|
||||
/// // Zero out lower-triangular terms
|
||||
/// csr.triplet_iter_mut()
|
||||
/// .filter(|(i, j, _)| j < i)
|
||||
/// .for_each(|(_, _, v)| *v = 0);
|
||||
///
|
||||
/// let triplets: Vec<_> = csr.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||
/// assert_eq!(triplets, vec![(0, 0, 1), (0, 2, 2), (1, 1, 3), (2, 0, 0)]);
|
||||
/// ```
|
||||
pub fn triplet_iter_mut(&mut self) -> CsrTripletIterMut<T> {
|
||||
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||
CsrTripletIterMut {
|
||||
pattern_iter: pattern.entries(),
|
||||
values_mut_iter: values.iter_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the row at the given row index.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
/// Panics if row index is out of bounds.
|
||||
#[inline]
|
||||
pub fn row(&self, index: usize) -> CsrRow<T> {
|
||||
self.get_row(index).expect("Row index must be in bounds")
|
||||
}
|
||||
|
||||
/// Mutable row access for the given row index.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
/// Panics if row index is out of bounds.
|
||||
#[inline]
|
||||
pub fn row_mut(&mut self, index: usize) -> CsrRowMut<T> {
|
||||
self.get_row_mut(index)
|
||||
.expect("Row index must be in bounds")
|
||||
}
|
||||
|
||||
/// Return the row at the given row index, or `None` if out of bounds.
|
||||
#[inline]
|
||||
pub fn get_row(&self, index: usize) -> Option<CsrRow<T>> {
|
||||
self.cs.get_lane(index).map(|lane| CsrRow { lane })
|
||||
}
|
||||
|
||||
/// Mutable row access for the given row index, or `None` if out of bounds.
|
||||
#[inline]
|
||||
pub fn get_row_mut(&mut self, index: usize) -> Option<CsrRowMut<T>> {
|
||||
self.cs.get_lane_mut(index).map(|lane| CsrRowMut { lane })
|
||||
}
|
||||
|
||||
/// An iterator over rows in the matrix.
|
||||
pub fn row_iter(&self) -> CsrRowIter<T> {
|
||||
CsrRowIter {
|
||||
lane_iter: CsLaneIter::new(self.pattern(), self.values()),
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable iterator over rows in the matrix.
|
||||
pub fn row_iter_mut(&mut self) -> CsrRowIterMut<T> {
|
||||
let (pattern, values) = self.cs.pattern_and_values_mut();
|
||||
CsrRowIterMut {
|
||||
lane_iter: CsLaneIterMut::new(pattern, values),
|
||||
}
|
||||
}
|
||||
|
||||
/// Disassembles the CSR matrix into its underlying offset, index and value arrays.
|
||||
///
|
||||
/// If the matrix contains the sole reference to the sparsity pattern,
|
||||
/// then the data is returned as-is. Otherwise, the sparsity pattern is cloned.
|
||||
///
|
||||
/// Examples
|
||||
/// --------
|
||||
///
|
||||
/// ```
|
||||
/// # use nalgebra_sparse::csr::CsrMatrix;
|
||||
/// let row_offsets = vec![0, 2, 3, 4];
|
||||
/// let col_indices = vec![0, 2, 1, 0];
|
||||
/// let values = vec![1, 2, 3, 4];
|
||||
/// let mut csr = CsrMatrix::try_from_csr_data(
|
||||
/// 3,
|
||||
/// 4,
|
||||
/// row_offsets.clone(),
|
||||
/// col_indices.clone(),
|
||||
/// values.clone())
|
||||
/// .unwrap();
|
||||
/// let (row_offsets2, col_indices2, values2) = csr.disassemble();
|
||||
/// assert_eq!(row_offsets2, row_offsets);
|
||||
/// assert_eq!(col_indices2, col_indices);
|
||||
/// assert_eq!(values2, values);
|
||||
/// ```
|
||||
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
|
||||
self.cs.disassemble()
|
||||
}
|
||||
|
||||
/// Returns the sparsity pattern and values associated with this matrix.
|
||||
pub fn into_pattern_and_values(self) -> (SparsityPattern, Vec<T>) {
|
||||
self.cs.into_pattern_and_values()
|
||||
}
|
||||
|
||||
/// Returns a reference to the sparsity pattern and a mutable reference to the values.
|
||||
#[inline]
|
||||
pub fn pattern_and_values_mut(&mut self) -> (&SparsityPattern, &mut [T]) {
|
||||
self.cs.pattern_and_values_mut()
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying sparsity pattern.
|
||||
pub fn pattern(&self) -> &SparsityPattern {
|
||||
self.cs.pattern()
|
||||
}
|
||||
|
||||
/// Reinterprets the CSR matrix as its transpose represented by a CSC matrix.
|
||||
///
|
||||
/// This operation does not touch the CSR data, and is effectively a no-op.
|
||||
pub fn transpose_as_csc(self) -> CscMatrix<T> {
|
||||
let (pattern, values) = self.cs.take_pattern_and_values();
|
||||
CscMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||
}
|
||||
|
||||
/// Returns an entry for the given row/col indices, or `None` if the indices are out of bounds.
|
||||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
/// stored column entries for the given row.
|
||||
pub fn get_entry(&self, row_index: usize, col_index: usize) -> Option<SparseEntry<T>> {
|
||||
self.cs.get_entry(row_index, col_index)
|
||||
}
|
||||
|
||||
/// Returns a mutable entry for the given row/col indices, or `None` if the indices are out
|
||||
/// of bounds.
|
||||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
/// stored column entries for the given row.
|
||||
pub fn get_entry_mut(
|
||||
&mut self,
|
||||
row_index: usize,
|
||||
col_index: usize,
|
||||
) -> Option<SparseEntryMut<T>> {
|
||||
self.cs.get_entry_mut(row_index, col_index)
|
||||
}
|
||||
|
||||
/// Returns an entry for the given row/col indices.
|
||||
///
|
||||
/// Same as `get_entry`, except that it directly panics upon encountering row/col indices
|
||||
/// out of bounds.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
/// Panics if `row_index` or `col_index` is out of bounds.
|
||||
pub fn index_entry(&self, row_index: usize, col_index: usize) -> SparseEntry<T> {
|
||||
self.get_entry(row_index, col_index)
|
||||
.expect("Out of bounds matrix indices encountered")
|
||||
}
|
||||
|
||||
/// Returns a mutable entry for the given row/col indices.
|
||||
///
|
||||
/// Same as `get_entry_mut`, except that it directly panics upon encountering row/col indices
|
||||
/// out of bounds.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
/// Panics if `row_index` or `col_index` is out of bounds.
|
||||
pub fn index_entry_mut(&mut self, row_index: usize, col_index: usize) -> SparseEntryMut<T> {
|
||||
self.get_entry_mut(row_index, col_index)
|
||||
.expect("Out of bounds matrix indices encountered")
|
||||
}
|
||||
|
||||
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data.
|
||||
pub fn csr_data(&self) -> (&[usize], &[usize], &[T]) {
|
||||
self.cs.cs_data()
|
||||
}
|
||||
|
||||
/// Returns a triplet of slices `(row_offsets, col_indices, values)` that make up the CSR data,
|
||||
/// where the `values` array is mutable.
|
||||
pub fn csr_data_mut(&mut self) -> (&[usize], &[usize], &mut [T]) {
|
||||
self.cs.cs_data_mut()
|
||||
}
|
||||
|
||||
/// Creates a sparse matrix that contains only the explicit entries decided by the
|
||||
/// given predicate.
|
||||
pub fn filter<P>(&self, predicate: P) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
P: Fn(usize, usize, &T) -> bool,
|
||||
{
|
||||
Self {
|
||||
cs: self
|
||||
.cs
|
||||
.filter(|row_idx, col_idx, v| predicate(row_idx, col_idx, v)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new matrix representing the upper triangular part of this matrix.
|
||||
///
|
||||
/// The result includes the diagonal of the matrix.
|
||||
pub fn upper_triangle(&self) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
self.filter(|i, j, _| i <= j)
|
||||
}
|
||||
|
||||
/// Returns a new matrix representing the lower triangular part of this matrix.
|
||||
///
|
||||
/// The result includes the diagonal of the matrix.
|
||||
pub fn lower_triangle(&self) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
self.filter(|i, j, _| i >= j)
|
||||
}
|
||||
|
||||
/// Returns the diagonal of the matrix as a sparse matrix.
|
||||
pub fn diagonal_as_csr(&self) -> Self
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
Self {
|
||||
cs: self.cs.diagonal_as_matrix(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the transpose of the matrix.
|
||||
pub fn transpose(&self) -> CsrMatrix<T>
|
||||
where
|
||||
T: Scalar,
|
||||
{
|
||||
CscMatrix::from(self).transpose_as_csr()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert pattern format errors into more meaningful CSR-specific errors.
|
||||
///
|
||||
/// This ensures that the terminology is consistent: we are talking about rows and columns,
|
||||
/// not lanes, major and minor dimensions.
|
||||
fn pattern_format_error_to_csr_error(err: SparsityPatternFormatError) -> SparseFormatError {
|
||||
use SparseFormatError as E;
|
||||
use SparseFormatErrorKind as K;
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use SparsityPatternFormatError::*;
|
||||
|
||||
match err {
|
||||
InvalidOffsetArrayLength => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Length of row offset array is not equal to nrows + 1.",
|
||||
),
|
||||
InvalidOffsetFirstLast => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"First or last row offset is inconsistent with format specification.",
|
||||
),
|
||||
NonmonotonicOffsets => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Row offsets are not monotonically increasing.",
|
||||
),
|
||||
NonmonotonicMinorIndices => E::from_kind_and_msg(
|
||||
K::InvalidStructure,
|
||||
"Column indices are not monotonically increasing (sorted) within each row.",
|
||||
),
|
||||
MinorIndexOutOfBounds => {
|
||||
E::from_kind_and_msg(K::IndexOutOfBounds, "Column indices are out of bounds.")
|
||||
}
|
||||
PatternDuplicateEntry => {
|
||||
E::from_kind_and_msg(K::DuplicateEntry, "Matrix data contains duplicate entries.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterator type for iterating over triplets in a CSR matrix.
|
||||
#[derive(Debug)]
|
||||
pub struct CsrTripletIter<'a, T> {
|
||||
pattern_iter: SparsityPatternIter<'a>,
|
||||
values_iter: Iter<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T: Clone> CsrTripletIter<'a, T> {
|
||||
/// Adapts the triplet iterator to return owned values.
|
||||
///
|
||||
/// The triplet iterator returns references to the values. This method adapts the iterator
|
||||
/// so that the values are cloned.
|
||||
#[inline]
|
||||
pub fn cloned_values(self) -> impl 'a + Iterator<Item = (usize, usize, T)> {
|
||||
self.map(|(i, j, v)| (i, j, v.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsrTripletIter<'a, T> {
|
||||
type Item = (usize, usize, &'a T);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let next_entry = self.pattern_iter.next();
|
||||
let next_value = self.values_iter.next();
|
||||
|
||||
match (next_entry, next_value) {
|
||||
(Some((i, j)), Some(v)) => Some((i, j, v)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterator type for mutably iterating over triplets in a CSR matrix.
|
||||
#[derive(Debug)]
|
||||
pub struct CsrTripletIterMut<'a, T> {
|
||||
pattern_iter: SparsityPatternIter<'a>,
|
||||
values_mut_iter: IterMut<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsrTripletIterMut<'a, T> {
|
||||
type Item = (usize, usize, &'a mut T);
|
||||
|
||||
#[inline]
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let next_entry = self.pattern_iter.next();
|
||||
let next_value = self.values_mut_iter.next();
|
||||
|
||||
match (next_entry, next_value) {
|
||||
(Some((i, j)), Some(v)) => Some((i, j, v)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An immutable representation of a row in a CSR matrix.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CsrRow<'a, T> {
|
||||
lane: CsLane<'a, T>,
|
||||
}
|
||||
|
||||
/// A mutable representation of a row in a CSR matrix.
|
||||
///
|
||||
/// Note that only explicitly stored entries can be mutated. The sparsity pattern belonging
|
||||
/// to the row cannot be modified.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct CsrRowMut<'a, T> {
|
||||
lane: CsLaneMut<'a, T>,
|
||||
}
|
||||
|
||||
/// Implement the methods common to both CsrRow and CsrRowMut
|
||||
macro_rules! impl_csr_row_common_methods {
|
||||
($name:ty) => {
|
||||
impl<'a, T> $name {
|
||||
/// The number of global columns in the row.
|
||||
#[inline]
|
||||
pub fn ncols(&self) -> usize {
|
||||
self.lane.minor_dim()
|
||||
}
|
||||
|
||||
/// The number of non-zeros in this row.
|
||||
#[inline]
|
||||
pub fn nnz(&self) -> usize {
|
||||
self.lane.nnz()
|
||||
}
|
||||
|
||||
/// The column indices corresponding to explicitly stored entries in this row.
|
||||
#[inline]
|
||||
pub fn col_indices(&self) -> &[usize] {
|
||||
self.lane.minor_indices()
|
||||
}
|
||||
|
||||
/// The values corresponding to explicitly stored entries in this row.
|
||||
#[inline]
|
||||
pub fn values(&self) -> &[T] {
|
||||
self.lane.values()
|
||||
}
|
||||
|
||||
/// Returns an entry for the given global column index.
|
||||
///
|
||||
/// Each call to this function incurs the cost of a binary search among the explicitly
|
||||
/// stored column entries.
|
||||
#[inline]
|
||||
pub fn get_entry(&self, global_col_index: usize) -> Option<SparseEntry<T>> {
|
||||
self.lane.get_entry(global_col_index)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_csr_row_common_methods!(CsrRow<'a, T>);
|
||||
impl_csr_row_common_methods!(CsrRowMut<'a, T>);
|
||||
|
||||
impl<'a, T> CsrRowMut<'a, T> {
|
||||
/// Mutable access to the values corresponding to explicitly stored entries in this row.
|
||||
#[inline]
|
||||
pub fn values_mut(&mut self) -> &mut [T] {
|
||||
self.lane.values_mut()
|
||||
}
|
||||
|
||||
/// Provides simultaneous access to column indices and mutable values corresponding to the
|
||||
/// explicitly stored entries in this row.
|
||||
///
|
||||
/// This method primarily facilitates low-level access for methods that process data stored
|
||||
/// in CSR format directly.
|
||||
#[inline]
|
||||
pub fn cols_and_values_mut(&mut self) -> (&[usize], &mut [T]) {
|
||||
self.lane.indices_and_values_mut()
|
||||
}
|
||||
|
||||
/// Returns a mutable entry for the given global column index.
|
||||
#[inline]
|
||||
pub fn get_entry_mut(&mut self, global_col_index: usize) -> Option<SparseEntryMut<T>> {
|
||||
self.lane.get_entry_mut(global_col_index)
|
||||
}
|
||||
}
|
||||
|
||||
/// Row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||
pub struct CsrRowIter<'a, T> {
|
||||
lane_iter: CsLaneIter<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsrRowIter<'a, T> {
|
||||
type Item = CsrRow<'a, T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.lane_iter.next().map(|lane| CsrRow { lane })
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutable row iterator for [CsrMatrix](struct.CsrMatrix.html).
|
||||
pub struct CsrRowIterMut<'a, T> {
|
||||
lane_iter: CsLaneIterMut<'a, T>,
|
||||
}
|
||||
|
||||
impl<'a, T> Iterator for CsrRowIterMut<'a, T>
|
||||
where
|
||||
T: 'a,
|
||||
{
|
||||
type Item = CsrRowMut<'a, T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.lane_iter.next().map(|lane| CsrRowMut { lane })
|
||||
}
|
||||
}
|
|
@ -0,0 +1,373 @@
|
|||
use crate::csc::CscMatrix;
|
||||
use crate::ops::serial::spsolve_csc_lower_triangular;
|
||||
use crate::ops::Op;
|
||||
use crate::pattern::SparsityPattern;
|
||||
use core::{iter, mem};
|
||||
use nalgebra::{DMatrix, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
|
||||
use std::fmt::{Display, Formatter};
|
||||
|
||||
/// A symbolic sparse Cholesky factorization of a CSC matrix.
|
||||
///
|
||||
/// The symbolic factorization computes the sparsity pattern of `L`, the Cholesky factor.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct CscSymbolicCholesky {
|
||||
// Pattern of the original matrix that was decomposed
|
||||
m_pattern: SparsityPattern,
|
||||
l_pattern: SparsityPattern,
|
||||
// u in this context is L^T, so that M = L L^T
|
||||
u_pattern: SparsityPattern,
|
||||
}
|
||||
|
||||
impl CscSymbolicCholesky {
|
||||
/// Compute the symbolic factorization for a sparsity pattern belonging to a CSC matrix.
|
||||
///
|
||||
/// The sparsity pattern must be symmetric. However, this is not enforced, and it is the
|
||||
/// responsibility of the user to ensure that this property holds.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the sparsity pattern is not square.
|
||||
pub fn factor(pattern: SparsityPattern) -> Self {
|
||||
assert_eq!(
|
||||
pattern.major_dim(),
|
||||
pattern.minor_dim(),
|
||||
"Major and minor dimensions must be the same (square matrix)."
|
||||
);
|
||||
let (l_pattern, u_pattern) = nonzero_pattern(&pattern);
|
||||
Self {
|
||||
m_pattern: pattern,
|
||||
l_pattern,
|
||||
u_pattern,
|
||||
}
|
||||
}
|
||||
|
||||
/// The pattern of the Cholesky factor `L`.
|
||||
pub fn l_pattern(&self) -> &SparsityPattern {
|
||||
&self.l_pattern
|
||||
}
|
||||
}
|
||||
|
||||
/// A sparse Cholesky factorization `A = L L^T` of a [`CscMatrix`].
|
||||
///
|
||||
/// The factor `L` is a sparse, lower-triangular matrix. See the article on [Wikipedia] for
|
||||
/// more information.
|
||||
///
|
||||
/// The implementation is a port of the `CsCholesky` implementation in `nalgebra`. It is similar
|
||||
/// to Tim Davis' [`CSparse`]. The current implementation performs no fill-in reduction, and can
|
||||
/// therefore be expected to produce much too dense Cholesky factors for many matrices.
|
||||
/// It is therefore not currently recommended to use this implementation for serious projects.
|
||||
///
|
||||
/// [`CSparse`]: https://epubs.siam.org/doi/book/10.1137/1.9780898718881
|
||||
/// [Wikipedia]: https://en.wikipedia.org/wiki/Cholesky_decomposition
|
||||
// TODO: We should probably implement PartialEq/Eq, but in that case we'd probably need a
|
||||
// custom implementation, due to the need to exclude the workspace arrays
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CscCholesky<T> {
|
||||
// Pattern of the original matrix
|
||||
m_pattern: SparsityPattern,
|
||||
l_factor: CscMatrix<T>,
|
||||
u_pattern: SparsityPattern,
|
||||
work_x: Vec<T>,
|
||||
work_c: Vec<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
#[non_exhaustive]
|
||||
/// Possible errors produced by the Cholesky factorization.
|
||||
pub enum CholeskyError {
|
||||
/// The matrix is not positive definite.
|
||||
NotPositiveDefinite,
|
||||
}
|
||||
|
||||
impl Display for CholeskyError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Matrix is not positive definite")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for CholeskyError {}
|
||||
|
||||
impl<T: RealField> CscCholesky<T> {
|
||||
/// Computes the numerical Cholesky factorization associated with the given
|
||||
/// symbolic factorization and the provided values.
|
||||
///
|
||||
/// The values correspond to the non-zero values of the CSC matrix for which the
|
||||
/// symbolic factorization was computed.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the numerical factorization fails. This can occur if the matrix is not
|
||||
/// symmetric positive definite.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the number of values differ from the number of non-zeros of the sparsity pattern
|
||||
/// of the matrix that was symbolically factored.
|
||||
pub fn factor_numerical(
|
||||
symbolic: CscSymbolicCholesky,
|
||||
values: &[T],
|
||||
) -> Result<Self, CholeskyError> {
|
||||
assert_eq!(
|
||||
symbolic.l_pattern.nnz(),
|
||||
symbolic.u_pattern.nnz(),
|
||||
"u is just the transpose of l, so should have the same nnz"
|
||||
);
|
||||
|
||||
let l_nnz = symbolic.l_pattern.nnz();
|
||||
let l_values = vec![T::zero(); l_nnz];
|
||||
let l_factor =
|
||||
CscMatrix::try_from_pattern_and_values(symbolic.l_pattern, l_values).unwrap();
|
||||
|
||||
let (nrows, ncols) = (l_factor.nrows(), l_factor.ncols());
|
||||
|
||||
let mut factorization = CscCholesky {
|
||||
m_pattern: symbolic.m_pattern,
|
||||
l_factor,
|
||||
u_pattern: symbolic.u_pattern,
|
||||
work_x: vec![T::zero(); nrows],
|
||||
// Fill with MAX so that things hopefully totally fail if values are not
|
||||
// overwritten. Might be easier to debug this way
|
||||
work_c: vec![usize::MAX, ncols],
|
||||
};
|
||||
|
||||
factorization.refactor(values)?;
|
||||
Ok(factorization)
|
||||
}
|
||||
|
||||
/// Computes the Cholesky factorization of the provided matrix.
|
||||
///
|
||||
/// The matrix must be symmetric positive definite. Symmetry is not checked, and it is up
|
||||
/// to the user to enforce this property.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the numerical factorization fails. This can occur if the matrix is not
|
||||
/// symmetric positive definite.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the matrix is not square.
|
||||
pub fn factor(matrix: &CscMatrix<T>) -> Result<Self, CholeskyError> {
|
||||
let symbolic = CscSymbolicCholesky::factor(matrix.pattern().clone());
|
||||
Self::factor_numerical(symbolic, matrix.values())
|
||||
}
|
||||
|
||||
/// Re-computes the factorization for a new set of non-zero values.
|
||||
///
|
||||
/// This is useful when the values of a matrix changes, but the sparsity pattern remains
|
||||
/// constant.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the numerical factorization fails. This can occur if the matrix is not
|
||||
/// symmetric positive definite.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the number of values does not match the number of non-zeros in the sparsity
|
||||
/// pattern.
|
||||
pub fn refactor(&mut self, values: &[T]) -> Result<(), CholeskyError> {
|
||||
self.decompose_left_looking(values)
|
||||
}
|
||||
|
||||
/// Returns a reference to the Cholesky factor `L`.
|
||||
pub fn l(&self) -> &CscMatrix<T> {
|
||||
&self.l_factor
|
||||
}
|
||||
|
||||
/// Returns the Cholesky factor `L`.
|
||||
pub fn take_l(self) -> CscMatrix<T> {
|
||||
self.l_factor
|
||||
}
|
||||
|
||||
/// Perform a numerical left-looking cholesky decomposition of a matrix with the same structure as the
|
||||
/// one used to initialize `self`, but with different non-zero values provided by `values`.
|
||||
fn decompose_left_looking(&mut self, values: &[T]) -> Result<(), CholeskyError> {
|
||||
assert!(
|
||||
values.len() >= self.m_pattern.nnz(),
|
||||
// TODO: Improve error message
|
||||
"The set of values is too small."
|
||||
);
|
||||
|
||||
let n = self.l_factor.nrows();
|
||||
|
||||
// Reset `work_c` to the column pointers of `l`.
|
||||
self.work_c.clear();
|
||||
self.work_c.extend_from_slice(self.l_factor.col_offsets());
|
||||
|
||||
unsafe {
|
||||
for k in 0..n {
|
||||
// Scatter the k-th column of the original matrix with the values provided.
|
||||
let range_begin = *self.m_pattern.major_offsets().get_unchecked(k);
|
||||
let range_end = *self.m_pattern.major_offsets().get_unchecked(k + 1);
|
||||
let range_k = range_begin..range_end;
|
||||
|
||||
*self.work_x.get_unchecked_mut(k) = T::zero();
|
||||
for p in range_k.clone() {
|
||||
let irow = *self.m_pattern.minor_indices().get_unchecked(p);
|
||||
|
||||
if irow >= k {
|
||||
*self.work_x.get_unchecked_mut(irow) = *values.get_unchecked(p);
|
||||
}
|
||||
}
|
||||
|
||||
for &j in self.u_pattern.lane(k) {
|
||||
let factor = -*self
|
||||
.l_factor
|
||||
.values()
|
||||
.get_unchecked(*self.work_c.get_unchecked(j));
|
||||
*self.work_c.get_unchecked_mut(j) += 1;
|
||||
|
||||
if j < k {
|
||||
let col_j = self.l_factor.col(j);
|
||||
let col_j_entries = col_j.row_indices().iter().zip(col_j.values());
|
||||
for (&z, val) in col_j_entries {
|
||||
if z >= k {
|
||||
*self.work_x.get_unchecked_mut(z) += val.inlined_clone() * factor;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let diag = *self.work_x.get_unchecked(k);
|
||||
|
||||
if diag > T::zero() {
|
||||
let denom = diag.sqrt();
|
||||
|
||||
{
|
||||
let (offsets, _, values) = self.l_factor.csc_data_mut();
|
||||
*values.get_unchecked_mut(*offsets.get_unchecked(k)) = denom;
|
||||
}
|
||||
|
||||
let mut col_k = self.l_factor.col_mut(k);
|
||||
let (col_k_rows, col_k_values) = col_k.rows_and_values_mut();
|
||||
let col_k_entries = col_k_rows.iter().zip(col_k_values);
|
||||
for (&p, val) in col_k_entries {
|
||||
*val = *self.work_x.get_unchecked(p) / denom;
|
||||
*self.work_x.get_unchecked_mut(p) = T::zero();
|
||||
}
|
||||
} else {
|
||||
return Err(CholeskyError::NotPositiveDefinite);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Solves the system `A X = B`, where `X` and `B` are dense matrices.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `B` is not square.
|
||||
pub fn solve<'a>(&'a self, b: impl Into<DMatrixSlice<'a, T>>) -> DMatrix<T> {
|
||||
let b = b.into();
|
||||
let mut output = b.clone_owned();
|
||||
self.solve_mut(&mut output);
|
||||
output
|
||||
}
|
||||
|
||||
/// Solves the system `AX = B`, where `X` and `B` are dense matrices.
|
||||
///
|
||||
/// The result is stored in-place in `b`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `b` is not square.
|
||||
pub fn solve_mut<'a>(&'a self, b: impl Into<DMatrixSliceMut<'a, T>>) {
|
||||
let expect_msg = "If the Cholesky factorization succeeded,\
|
||||
then the triangular solve should never fail";
|
||||
// Solve LY = B
|
||||
let mut y = b.into();
|
||||
spsolve_csc_lower_triangular(Op::NoOp(self.l()), &mut y).expect(expect_msg);
|
||||
|
||||
// Solve L^T X = Y
|
||||
let mut x = y;
|
||||
spsolve_csc_lower_triangular(Op::Transpose(self.l()), &mut x).expect(expect_msg);
|
||||
}
|
||||
}
|
||||
|
||||
fn reach(
|
||||
pattern: &SparsityPattern,
|
||||
j: usize,
|
||||
max_j: usize,
|
||||
tree: &[usize],
|
||||
marks: &mut Vec<bool>,
|
||||
out: &mut Vec<usize>,
|
||||
) {
|
||||
marks.clear();
|
||||
marks.resize(tree.len(), false);
|
||||
|
||||
// TODO: avoid all those allocations.
|
||||
let mut tmp = Vec::new();
|
||||
let mut res = Vec::new();
|
||||
|
||||
for &irow in pattern.lane(j) {
|
||||
let mut curr = irow;
|
||||
while curr != usize::max_value() && curr <= max_j && !marks[curr] {
|
||||
marks[curr] = true;
|
||||
tmp.push(curr);
|
||||
curr = tree[curr];
|
||||
}
|
||||
|
||||
tmp.append(&mut res);
|
||||
mem::swap(&mut tmp, &mut res);
|
||||
}
|
||||
|
||||
res.sort_unstable();
|
||||
|
||||
out.append(&mut res);
|
||||
}
|
||||
|
||||
fn nonzero_pattern(m: &SparsityPattern) -> (SparsityPattern, SparsityPattern) {
|
||||
let etree = elimination_tree(m);
|
||||
// Note: We assume CSC, therefore rows == minor and cols == major
|
||||
let (nrows, ncols) = (m.minor_dim(), m.major_dim());
|
||||
let mut rows = Vec::with_capacity(m.nnz());
|
||||
let mut col_offsets = Vec::with_capacity(ncols + 1);
|
||||
let mut marks = Vec::new();
|
||||
|
||||
// NOTE: the following will actually compute the non-zero pattern of
|
||||
// the transpose of l.
|
||||
col_offsets.push(0);
|
||||
for i in 0..nrows {
|
||||
reach(m, i, i, &etree, &mut marks, &mut rows);
|
||||
col_offsets.push(rows.len());
|
||||
}
|
||||
|
||||
let u_pattern =
|
||||
SparsityPattern::try_from_offsets_and_indices(nrows, ncols, col_offsets, rows).unwrap();
|
||||
|
||||
// TODO: Avoid this transpose?
|
||||
let l_pattern = u_pattern.transpose();
|
||||
|
||||
(l_pattern, u_pattern)
|
||||
}
|
||||
|
||||
fn elimination_tree(pattern: &SparsityPattern) -> Vec<usize> {
|
||||
// Note: The pattern is assumed to of a CSC matrix, so the number of rows is
|
||||
// given by the minor dimension
|
||||
let nrows = pattern.minor_dim();
|
||||
let mut forest: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
|
||||
let mut ancestor: Vec<_> = iter::repeat(usize::max_value()).take(nrows).collect();
|
||||
|
||||
for k in 0..nrows {
|
||||
for &irow in pattern.lane(k) {
|
||||
let mut i = irow;
|
||||
|
||||
while i < k {
|
||||
let i_ancestor = ancestor[i];
|
||||
ancestor[i] = k;
|
||||
|
||||
if i_ancestor == usize::max_value() {
|
||||
forest[i] = k;
|
||||
break;
|
||||
}
|
||||
|
||||
i = i_ancestor;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
forest
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
//! Matrix factorization for sparse matrices.
|
||||
//!
|
||||
//! Currently, the only factorization provided here is the [`CscCholesky`] factorization.
|
||||
mod cholesky;
|
||||
|
||||
pub use cholesky::*;
|
|
@ -0,0 +1,267 @@
|
|||
//! Sparse matrices and algorithms for [nalgebra](https://www.nalgebra.org).
|
||||
//!
|
||||
//! This crate extends `nalgebra` with sparse matrix formats and operations on sparse matrices.
|
||||
//!
|
||||
//! ## Goals
|
||||
//! The long-term goals for this crate are listed below.
|
||||
//!
|
||||
//! - Provide proven sparse matrix formats in an easy-to-use and idiomatic Rust API that
|
||||
//! naturally integrates with `nalgebra`.
|
||||
//! - Provide additional expert-level APIs for fine-grained control over operations.
|
||||
//! - Integrate well with external sparse matrix libraries.
|
||||
//! - Provide native Rust high-performance routines, including parallel matrix operations.
|
||||
//!
|
||||
//! ## Highlighted current features
|
||||
//!
|
||||
//! - [CSR](csr::CsrMatrix), [CSC](csc::CscMatrix) and [COO](coo::CooMatrix) formats, and
|
||||
//! [conversions](`convert`) between them.
|
||||
//! - Common arithmetic operations are implemented. See the [`ops`] module.
|
||||
//! - Sparsity patterns in CSR and CSC matrices are explicitly represented by the
|
||||
//! [SparsityPattern](pattern::SparsityPattern) type, which encodes the invariants of the
|
||||
//! associated index data structures.
|
||||
//! - [proptest strategies](`proptest`) for sparse matrices when the feature
|
||||
//! `proptest-support` is enabled.
|
||||
//! - [matrixcompare support](https://crates.io/crates/matrixcompare) for effortless
|
||||
//! (approximate) comparison of matrices in test code (requires the `compare` feature).
|
||||
//!
|
||||
//! ## Current state
|
||||
//!
|
||||
//! The library is in an early, but usable state. The API has been designed to be extensible,
|
||||
//! but breaking changes will be necessary to implement several planned features. While it is
|
||||
//! backed by an extensive test suite, it has yet to be thoroughly battle-tested in real
|
||||
//! applications. Moreover, the focus so far has been on correctness and API design, with little
|
||||
//! focus on performance. Future improvements will include incremental performance enhancements.
|
||||
//!
|
||||
//! Current limitations:
|
||||
//!
|
||||
//! - Limited or no availability of sparse system solvers.
|
||||
//! - Limited support for complex numbers. Currently only arithmetic operations that do not
|
||||
//! rely on particular properties of complex numbers, such as e.g. conjugation, are
|
||||
//! supported.
|
||||
//! - No integration with external libraries.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! Add the following to your `Cargo.toml` file:
|
||||
//!
|
||||
//! ```toml
|
||||
//! [dependencies]
|
||||
//! nalgebra_sparse = "0.1"
|
||||
//! ```
|
||||
//!
|
||||
//! # Supported matrix formats
|
||||
//!
|
||||
//! | Format | Notes |
|
||||
//! | ------------------------|--------------------------------------------- |
|
||||
//! | [COO](`coo::CooMatrix`) | Well-suited for matrix construction. <br /> Ill-suited for algebraic operations. |
|
||||
//! | [CSR](`csr::CsrMatrix`) | Immutable sparsity pattern, suitable for algebraic operations. <br /> Fast row access. |
|
||||
//! | [CSC](`csc::CscMatrix`) | Immutable sparsity pattern, suitable for algebraic operations. <br /> Fast column access. |
|
||||
//!
|
||||
//! What format is best to use depends on the application. The most common use case for sparse
|
||||
//! matrices in science is the solution of sparse linear systems. Here we can differentiate between
|
||||
//! two common cases:
|
||||
//!
|
||||
//! - Direct solvers. Typically, direct solvers take their input in CSR or CSC format.
|
||||
//! - Iterative solvers. Many iterative solvers require only matrix-vector products,
|
||||
//! for which the CSR or CSC formats are suitable.
|
||||
//!
|
||||
//! The [COO](coo::CooMatrix) format is primarily intended for matrix construction.
|
||||
//! A common pattern is to use COO for construction, before converting to CSR or CSC for use
|
||||
//! in a direct solver or for computing matrix-vector products in an iterative solver.
|
||||
//! Some high-performance applications might also directly manipulate the CSR and/or CSC
|
||||
//! formats.
|
||||
//!
|
||||
//! # Example: COO -> CSR -> matrix-vector product
|
||||
//!
|
||||
//! ```rust
|
||||
//! use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
|
||||
//! use nalgebra::{DMatrix, DVector};
|
||||
//! use matrixcompare::assert_matrix_eq;
|
||||
//!
|
||||
//! // The dense representation of the matrix
|
||||
//! let dense = DMatrix::from_row_slice(3, 3,
|
||||
//! &[1.0, 0.0, 3.0,
|
||||
//! 2.0, 0.0, 1.3,
|
||||
//! 0.0, 0.0, 4.1]);
|
||||
//!
|
||||
//! // Build the equivalent COO representation. We only add the non-zero values
|
||||
//! let mut coo = CooMatrix::new(3, 3);
|
||||
//! // We can add elements in any order. For clarity, we do so in row-major order here.
|
||||
//! coo.push(0, 0, 1.0);
|
||||
//! coo.push(0, 2, 3.0);
|
||||
//! coo.push(1, 0, 2.0);
|
||||
//! coo.push(1, 2, 1.3);
|
||||
//! coo.push(2, 2, 4.1);
|
||||
//!
|
||||
//! // The simplest way to construct a CSR matrix is to first construct a COO matrix, and
|
||||
//! // then convert it to CSR. The `From` trait is implemented for conversions between different
|
||||
//! // sparse matrix types.
|
||||
//! // Alternatively, we can construct a matrix directly from the CSR data.
|
||||
//! // See the docs for CsrMatrix for how to do that.
|
||||
//! let csr = CsrMatrix::from(&coo);
|
||||
//!
|
||||
//! // Let's check that the CSR matrix and the dense matrix represent the same matrix.
|
||||
//! // We can use macros from the `matrixcompare` crate to easily do this, despite the fact that
|
||||
//! // we're comparing across two different matrix formats. Note that these macros are only really
|
||||
//! // appropriate for writing tests, however.
|
||||
//! assert_matrix_eq!(csr, dense);
|
||||
//!
|
||||
//! let x = DVector::from_column_slice(&[1.3, -4.0, 3.5]);
|
||||
//!
|
||||
//! // Compute the matrix-vector product y = A * x. We don't need to specify the type here,
|
||||
//! // but let's just do it to make sure we get what we expect
|
||||
//! let y: DVector<_> = &csr * &x;
|
||||
//!
|
||||
//! // Verify the result with a small element-wise absolute tolerance
|
||||
//! let y_expected = DVector::from_column_slice(&[11.8, 7.15, 14.35]);
|
||||
//! assert_matrix_eq!(y, y_expected, comp = abs, tol = 1e-9);
|
||||
//!
|
||||
//! // The above expression is simple, and gives easy to read code, but if we're doing this in a
|
||||
//! // loop, we'll have to keep allocating new vectors. If we determine that this is a bottleneck,
|
||||
//! // then we can resort to the lower level APIs for more control over the operations
|
||||
//! {
|
||||
//! use nalgebra_sparse::ops::{Op, serial::spmm_csr_dense};
|
||||
//! let mut y = y;
|
||||
//! // Compute y <- 0.0 * y + 1.0 * csr * dense. We store the result directly in `y`, without
|
||||
//! // any intermediate allocations
|
||||
//! spmm_csr_dense(0.0, &mut y, 1.0, Op::NoOp(&csr), Op::NoOp(&x));
|
||||
//! assert_matrix_eq!(y, y_expected, comp = abs, tol = 1e-9);
|
||||
//! }
|
||||
//! ```
|
||||
#![deny(non_camel_case_types)]
|
||||
#![deny(unused_parens)]
|
||||
#![deny(non_upper_case_globals)]
|
||||
#![deny(unused_qualifications)]
|
||||
#![deny(unused_results)]
|
||||
#![deny(missing_docs)]
|
||||
|
||||
pub extern crate nalgebra as na;
|
||||
pub mod convert;
|
||||
pub mod coo;
|
||||
pub mod csc;
|
||||
pub mod csr;
|
||||
pub mod factorization;
|
||||
pub mod ops;
|
||||
pub mod pattern;
|
||||
|
||||
pub(crate) mod cs;
|
||||
|
||||
#[cfg(feature = "proptest-support")]
|
||||
pub mod proptest;
|
||||
|
||||
#[cfg(feature = "compare")]
|
||||
mod matrixcompare;
|
||||
|
||||
use num_traits::Zero;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
pub use self::coo::CooMatrix;
|
||||
pub use self::csc::CscMatrix;
|
||||
pub use self::csr::CsrMatrix;
|
||||
|
||||
/// Errors produced by functions that expect well-formed sparse format data.
|
||||
#[derive(Debug)]
|
||||
pub struct SparseFormatError {
|
||||
kind: SparseFormatErrorKind,
|
||||
// Currently we only use an underlying error for generating the `Display` impl
|
||||
error: Box<dyn Error>,
|
||||
}
|
||||
|
||||
impl SparseFormatError {
|
||||
/// The type of error.
|
||||
pub fn kind(&self) -> &SparseFormatErrorKind {
|
||||
&self.kind
|
||||
}
|
||||
|
||||
pub(crate) fn from_kind_and_error(kind: SparseFormatErrorKind, error: Box<dyn Error>) -> Self {
|
||||
Self { kind, error }
|
||||
}
|
||||
|
||||
/// Helper functionality for more conveniently creating errors.
|
||||
pub(crate) fn from_kind_and_msg(kind: SparseFormatErrorKind, msg: &'static str) -> Self {
|
||||
Self::from_kind_and_error(kind, Box::<dyn Error>::from(msg))
|
||||
}
|
||||
}
|
||||
|
||||
/// The type of format error described by a [SparseFormatError](struct.SparseFormatError.html).
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum SparseFormatErrorKind {
|
||||
/// Indicates that the index data associated with the format contains at least one index
|
||||
/// out of bounds.
|
||||
IndexOutOfBounds,
|
||||
|
||||
/// Indicates that the provided data contains at least one duplicate entry, and the
|
||||
/// current format does not support duplicate entries.
|
||||
DuplicateEntry,
|
||||
|
||||
/// Indicates that the provided data for the format does not conform to the high-level
|
||||
/// structure of the format.
|
||||
///
|
||||
/// For example, the arrays defining the format data might have incompatible sizes.
|
||||
InvalidStructure,
|
||||
}
|
||||
|
||||
impl fmt::Display for SparseFormatError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.error)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for SparseFormatError {}
|
||||
|
||||
/// An entry in a sparse matrix.
|
||||
///
|
||||
/// Sparse matrices do not store all their entries explicitly. Therefore, entry (i, j) in the matrix
|
||||
/// can either be a reference to an explicitly stored element, or it is implicitly zero.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum SparseEntry<'a, T> {
|
||||
/// The entry is a reference to an explicitly stored element.
|
||||
///
|
||||
/// Note that the naming here is a misnomer: The element can still be zero, even though it
|
||||
/// is explicitly stored (a so-called "explicit zero").
|
||||
NonZero(&'a T),
|
||||
/// The entry is implicitly zero, i.e. it is not explicitly stored.
|
||||
Zero,
|
||||
}
|
||||
|
||||
impl<'a, T: Clone + Zero> SparseEntry<'a, T> {
|
||||
/// Returns the value represented by this entry.
|
||||
///
|
||||
/// Either clones the underlying reference or returns zero if the entry is not explicitly
|
||||
/// stored.
|
||||
pub fn into_value(self) -> T {
|
||||
match self {
|
||||
SparseEntry::NonZero(value) => value.clone(),
|
||||
SparseEntry::Zero => T::zero(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable entry in a sparse matrix.
|
||||
///
|
||||
/// See also `SparseEntry`.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum SparseEntryMut<'a, T> {
|
||||
/// The entry is a mutable reference to an explicitly stored element.
|
||||
///
|
||||
/// Note that the naming here is a misnomer: The element can still be zero, even though it
|
||||
/// is explicitly stored (a so-called "explicit zero").
|
||||
NonZero(&'a mut T),
|
||||
/// The entry is implicitly zero i.e. it is not explicitly stored.
|
||||
Zero,
|
||||
}
|
||||
|
||||
impl<'a, T: Clone + Zero> SparseEntryMut<'a, T> {
|
||||
/// Returns the value represented by this entry.
|
||||
///
|
||||
/// Either clones the underlying reference or returns zero if the entry is not explicitly
|
||||
/// stored.
|
||||
pub fn into_value(self) -> T {
|
||||
match self {
|
||||
SparseEntryMut::NonZero(value) => value.clone(),
|
||||
SparseEntryMut::Zero => T::zero(),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
//! Implements core traits for use with `matrixcompare`.
|
||||
use crate::coo::CooMatrix;
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::csr::CsrMatrix;
|
||||
use matrixcompare_core;
|
||||
use matrixcompare_core::{Access, SparseAccess};
|
||||
|
||||
macro_rules! impl_matrix_for_csr_csc {
|
||||
($MatrixType:ident) => {
|
||||
impl<T: Clone> SparseAccess<T> for $MatrixType<T> {
|
||||
fn nnz(&self) -> usize {
|
||||
$MatrixType::nnz(self)
|
||||
}
|
||||
|
||||
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
|
||||
self.triplet_iter()
|
||||
.map(|(i, j, v)| (i, j, v.clone()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone> matrixcompare_core::Matrix<T> for $MatrixType<T> {
|
||||
fn rows(&self) -> usize {
|
||||
self.nrows()
|
||||
}
|
||||
|
||||
fn cols(&self) -> usize {
|
||||
self.ncols()
|
||||
}
|
||||
|
||||
fn access(&self) -> Access<T> {
|
||||
Access::Sparse(self)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_matrix_for_csr_csc!(CsrMatrix);
|
||||
impl_matrix_for_csr_csc!(CscMatrix);
|
||||
|
||||
impl<T: Clone> SparseAccess<T> for CooMatrix<T> {
|
||||
fn nnz(&self) -> usize {
|
||||
CooMatrix::nnz(self)
|
||||
}
|
||||
|
||||
fn fetch_triplets(&self) -> Vec<(usize, usize, T)> {
|
||||
self.triplet_iter()
|
||||
.map(|(i, j, v)| (i, j, v.clone()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone> matrixcompare_core::Matrix<T> for CooMatrix<T> {
|
||||
fn rows(&self) -> usize {
|
||||
self.nrows()
|
||||
}
|
||||
|
||||
fn cols(&self) -> usize {
|
||||
self.ncols()
|
||||
}
|
||||
|
||||
fn access(&self) -> Access<T> {
|
||||
Access::Sparse(self)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,331 @@
|
|||
use crate::csc::CscMatrix;
|
||||
use crate::csr::CsrMatrix;
|
||||
|
||||
use crate::ops::serial::{
|
||||
spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_pattern,
|
||||
spmm_csc_prealloc, spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc,
|
||||
};
|
||||
use crate::ops::Op;
|
||||
use nalgebra::allocator::Allocator;
|
||||
use nalgebra::base::storage::Storage;
|
||||
use nalgebra::constraint::{DimEq, ShapeConstraint};
|
||||
use nalgebra::{
|
||||
ClosedAdd, ClosedDiv, ClosedMul, ClosedSub, DefaultAllocator, Dim, Dynamic, Matrix, MatrixMN,
|
||||
Scalar, U1,
|
||||
};
|
||||
use num_traits::{One, Zero};
|
||||
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Neg, Sub};
|
||||
|
||||
/// Helper macro for implementing binary operators for different matrix types
|
||||
/// See below for usage.
|
||||
macro_rules! impl_bin_op {
|
||||
($trait:ident, $method:ident,
|
||||
<$($life:lifetime),* $(,)? $($scalar_type:ident $(: $bounds:path)?)?>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
|
||||
=>
|
||||
{
|
||||
impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
|
||||
where
|
||||
// Note: The Neg bound is currently required because we delegate e.g.
|
||||
// Sub to SpAdd with negative coefficients. This is not well-defined for
|
||||
// unsigned data types.
|
||||
$($scalar_type: $($bounds + )? Scalar + ClosedAdd + ClosedSub + ClosedMul + Zero + One + Neg<Output=T>)?
|
||||
{
|
||||
type Output = $ret;
|
||||
fn $method(self, $b: $b_type) -> Self::Output {
|
||||
let $a = self;
|
||||
$body
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Implements a +/- b for all combinations of reference and owned matrices, for
|
||||
/// CsrMatrix or CscMatrix.
|
||||
macro_rules! impl_sp_plus_minus {
|
||||
// We first match on some special-case syntax, and forward to the actual implementation
|
||||
($matrix_type:ident, $spadd_fn:ident, +) => {
|
||||
impl_sp_plus_minus!(Add, add, $matrix_type, $spadd_fn, +, T::one());
|
||||
};
|
||||
($matrix_type:ident, $spadd_fn:ident, -) => {
|
||||
impl_sp_plus_minus!(Sub, sub, $matrix_type, $spadd_fn, -, -T::one());
|
||||
};
|
||||
($trait:ident, $method:ident, $matrix_type:ident, $spadd_fn:ident, $sign:tt, $factor:expr) => {
|
||||
impl_bin_op!($trait, $method,
|
||||
<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||
// If both matrices have the same pattern, then we can immediately re-use it
|
||||
let pattern = spadd_pattern(a.pattern(), b.pattern());
|
||||
let values = vec![T::zero(); pattern.nnz()];
|
||||
// We are giving data that is valid by definition, so it is safe to unwrap below
|
||||
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
||||
.unwrap();
|
||||
$spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap();
|
||||
$spadd_fn(T::one(), &mut result, $factor * T::one(), Op::NoOp(&b)).unwrap();
|
||||
result
|
||||
});
|
||||
|
||||
impl_bin_op!($trait, $method,
|
||||
<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||
&a $sign b
|
||||
});
|
||||
|
||||
impl_bin_op!($trait, $method,
|
||||
<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||
a $sign &b
|
||||
});
|
||||
impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
|
||||
a $sign &b
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, +);
|
||||
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, -);
|
||||
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, +);
|
||||
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, -);
|
||||
|
||||
macro_rules! impl_mul {
|
||||
($($args:tt)*) => {
|
||||
impl_bin_op!(Mul, mul, $($args)*);
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements a + b for all combinations of reference and owned matrices, for
|
||||
/// CsrMatrix or CscMatrix.
|
||||
macro_rules! impl_spmm {
|
||||
($matrix_type:ident, $pattern_fn:expr, $spmm_fn:expr) => {
|
||||
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
|
||||
let pattern = $pattern_fn(a.pattern(), b.pattern());
|
||||
let values = vec![T::zero(); pattern.nnz()];
|
||||
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
|
||||
.unwrap();
|
||||
$spmm_fn(T::zero(),
|
||||
&mut result,
|
||||
T::one(),
|
||||
Op::NoOp(a),
|
||||
Op::NoOp(b))
|
||||
.expect("Internal error: spmm failed (please debug).");
|
||||
result
|
||||
});
|
||||
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { a * &b});
|
||||
impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { &a * b});
|
||||
impl_mul!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { &a * &b});
|
||||
}
|
||||
}
|
||||
|
||||
impl_spmm!(CsrMatrix, spmm_csr_pattern, spmm_csr_prealloc);
|
||||
// Need to switch order of operations for CSC pattern
|
||||
impl_spmm!(CscMatrix, spmm_csc_pattern, spmm_csc_prealloc);
|
||||
|
||||
/// Implements Scalar * Matrix operations for *concrete* scalar types. The reason this is necessary
|
||||
/// is that we are not able to implement Mul<Matrix<T>> for all T generically due to orphan rules.
|
||||
macro_rules! impl_concrete_scalar_matrix_mul {
|
||||
($matrix_type:ident, $($scalar_type:ty),*) => {
|
||||
// For each concrete scalar type, forward the implementation of scalar * matrix
|
||||
// to matrix * scalar, which we have already implemented through generics
|
||||
$(
|
||||
impl_mul!(<>(a: $scalar_type, b: $matrix_type<$scalar_type>)
|
||||
-> $matrix_type<$scalar_type> { b * a });
|
||||
impl_mul!(<'a>(a: $scalar_type, b: &'a $matrix_type<$scalar_type>)
|
||||
-> $matrix_type<$scalar_type> { b * a });
|
||||
impl_mul!(<'a>(a: &'a $scalar_type, b: $matrix_type<$scalar_type>)
|
||||
-> $matrix_type<$scalar_type> { b * (*a) });
|
||||
impl_mul!(<'a>(a: &'a $scalar_type, b: &'a $matrix_type<$scalar_type>)
|
||||
-> $matrix_type<$scalar_type> { b * *a });
|
||||
)*
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements multiplication between matrix and scalar for various matrix types
|
||||
macro_rules! impl_scalar_mul {
|
||||
($matrix_type: ident) => {
|
||||
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
|
||||
let values: Vec<_> = a.values()
|
||||
.iter()
|
||||
.map(|v_i| v_i.inlined_clone() * b.inlined_clone())
|
||||
.collect();
|
||||
$matrix_type::try_from_pattern_and_values(a.pattern().clone(), values).unwrap()
|
||||
});
|
||||
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> {
|
||||
a * &b
|
||||
});
|
||||
impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
|
||||
let mut a = a;
|
||||
for value in a.values_mut() {
|
||||
*value = b.inlined_clone() * value.inlined_clone();
|
||||
}
|
||||
a
|
||||
});
|
||||
impl_mul!(<T>(a: $matrix_type<T>, b: T) -> $matrix_type<T> {
|
||||
a * &b
|
||||
});
|
||||
impl_concrete_scalar_matrix_mul!(
|
||||
$matrix_type,
|
||||
i8, i16, i32, i64, isize, f32, f64);
|
||||
|
||||
impl<T> MulAssign<T> for $matrix_type<T>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
{
|
||||
fn mul_assign(&mut self, scalar: T) {
|
||||
for val in self.values_mut() {
|
||||
*val *= scalar.inlined_clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> MulAssign<&'a T> for $matrix_type<T>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One
|
||||
{
|
||||
fn mul_assign(&mut self, scalar: &'a T) {
|
||||
for val in self.values_mut() {
|
||||
*val *= scalar.inlined_clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl_scalar_mul!(CsrMatrix);
|
||||
impl_scalar_mul!(CscMatrix);
|
||||
|
||||
macro_rules! impl_neg {
|
||||
($matrix_type:ident) => {
|
||||
impl<T> Neg for $matrix_type<T>
|
||||
where
|
||||
T: Scalar + Neg<Output = T>,
|
||||
{
|
||||
type Output = $matrix_type<T>;
|
||||
|
||||
fn neg(mut self) -> Self::Output {
|
||||
for v_i in self.values_mut() {
|
||||
*v_i = -v_i.inlined_clone();
|
||||
}
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Neg for &'a $matrix_type<T>
|
||||
where
|
||||
T: Scalar + Neg<Output = T>,
|
||||
{
|
||||
type Output = $matrix_type<T>;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
// TODO: This is inefficient. Ideally we'd have a method that would let us
|
||||
// obtain both the sparsity pattern and values from the matrix,
|
||||
// and then modify the values before creating a new matrix from the pattern
|
||||
// and negated values.
|
||||
-self.clone()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_neg!(CsrMatrix);
|
||||
impl_neg!(CscMatrix);
|
||||
|
||||
macro_rules! impl_div {
|
||||
($matrix_type:ident) => {
|
||||
impl_bin_op!(Div, div, <T: ClosedDiv>(matrix: $matrix_type<T>, scalar: T) -> $matrix_type<T> {
|
||||
let mut matrix = matrix;
|
||||
matrix /= scalar;
|
||||
matrix
|
||||
});
|
||||
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: $matrix_type<T>, scalar: &T) -> $matrix_type<T> {
|
||||
matrix / scalar.inlined_clone()
|
||||
});
|
||||
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: T) -> $matrix_type<T> {
|
||||
let new_values = matrix.values()
|
||||
.iter()
|
||||
.map(|v_i| v_i.inlined_clone() / scalar.inlined_clone())
|
||||
.collect();
|
||||
$matrix_type::try_from_pattern_and_values(matrix.pattern().clone(), new_values)
|
||||
.unwrap()
|
||||
});
|
||||
impl_bin_op!(Div, div, <'a, T: ClosedDiv>(matrix: &'a $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> {
|
||||
matrix / scalar.inlined_clone()
|
||||
});
|
||||
|
||||
impl<T> DivAssign<T> for $matrix_type<T>
|
||||
where T : Scalar + ClosedAdd + ClosedMul + ClosedDiv + Zero + One
|
||||
{
|
||||
fn div_assign(&mut self, scalar: T) {
|
||||
self.values_mut().iter_mut().for_each(|v_i| *v_i /= scalar.inlined_clone());
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> DivAssign<&'a T> for $matrix_type<T>
|
||||
where T : Scalar + ClosedAdd + ClosedMul + ClosedDiv + Zero + One
|
||||
{
|
||||
fn div_assign(&mut self, scalar: &'a T) {
|
||||
*self /= scalar.inlined_clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl_div!(CsrMatrix);
|
||||
impl_div!(CscMatrix);
|
||||
|
||||
macro_rules! impl_spmm_cs_dense {
|
||||
($matrix_type_name:ident, $spmm_fn:ident) => {
|
||||
// Implement ref-ref
|
||||
impl_spmm_cs_dense!(&'a $matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
|
||||
let (_, ncols) = rhs.data.shape();
|
||||
let nrows = Dynamic::new(lhs.nrows());
|
||||
let mut result = MatrixMN::<T, Dynamic, C>::zeros_generic(nrows, ncols);
|
||||
$spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(lhs), Op::NoOp(rhs));
|
||||
result
|
||||
});
|
||||
|
||||
// Implement the other combinations by deferring to ref-ref
|
||||
impl_spmm_cs_dense!(&'a $matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
|
||||
lhs * &rhs
|
||||
});
|
||||
impl_spmm_cs_dense!($matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
|
||||
&lhs * rhs
|
||||
});
|
||||
impl_spmm_cs_dense!($matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
|
||||
&lhs * &rhs
|
||||
});
|
||||
};
|
||||
|
||||
// Main body of the macro. The first pattern just forwards to this pattern but with
|
||||
// different arguments
|
||||
($sparse_matrix_type:ty, $dense_matrix_type:ty, $spmm_fn:ident,
|
||||
|$lhs:ident, $rhs:ident| $body:tt) =>
|
||||
{
|
||||
impl<'a, T, R, C, S> Mul<$dense_matrix_type> for $sparse_matrix_type
|
||||
where
|
||||
T: Scalar + ClosedMul + ClosedAdd + ClosedSub + ClosedDiv + Neg + Zero + One,
|
||||
R: Dim,
|
||||
C: Dim,
|
||||
S: Storage<T, R, C>,
|
||||
DefaultAllocator: Allocator<T, Dynamic, C>,
|
||||
// TODO: Is it possible to simplify these bounds?
|
||||
ShapeConstraint:
|
||||
// Bounds so that we can turn MatrixMN<T, Dynamic, C> into a DMatrixSliceMut
|
||||
DimEq<U1, <<DefaultAllocator as Allocator<T, Dynamic, C>>::Buffer as Storage<T, Dynamic, C>>::RStride>
|
||||
+ DimEq<C, Dynamic>
|
||||
+ DimEq<Dynamic, <<DefaultAllocator as Allocator<T, Dynamic, C>>::Buffer as Storage<T, Dynamic, C>>::CStride>
|
||||
// Bounds so that we can turn &Matrix<T, R, C, S> into a DMatrixSlice
|
||||
+ DimEq<U1, S::RStride>
|
||||
+ DimEq<R, Dynamic>
|
||||
+ DimEq<Dynamic, S::CStride>
|
||||
{
|
||||
// We need the column dimension to be generic, so that if RHS is a vector, then
|
||||
// we also get a vector (and not a matrix)
|
||||
type Output = MatrixMN<T, Dynamic, C>;
|
||||
|
||||
fn mul(self, rhs: $dense_matrix_type) -> Self::Output {
|
||||
let $lhs = self;
|
||||
let $rhs = rhs;
|
||||
$body
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl_spmm_cs_dense!(CsrMatrix, spmm_csr_dense);
|
||||
impl_spmm_cs_dense!(CscMatrix, spmm_csc_dense);
|
|
@ -0,0 +1,194 @@
|
|||
//! Sparse matrix arithmetic operations.
|
||||
//!
|
||||
//! This module contains a number of routines for sparse matrix arithmetic. These routines are
|
||||
//! primarily intended for "expert usage". Most users should prefer to use standard
|
||||
//! `std::ops` operations for simple and readable code when possible. The routines provided here
|
||||
//! offer more control over allocation, and allow fusing some low-level operations for higher
|
||||
//! performance.
|
||||
//!
|
||||
//! The available operations are organized by backend. Currently, only the [`serial`] backend
|
||||
//! is available. In the future, backends that expose parallel operations may become available.
|
||||
//! All `std::ops` implementations will remain single-threaded and powered by the
|
||||
//! `serial` backend.
|
||||
//!
|
||||
//! Many routines are able to implicitly transpose matrices involved in the operation.
|
||||
//! For example, the routine [`spadd_csr_prealloc`](serial::spadd_csr_prealloc) performs the
|
||||
//! operation `C <- beta * C + alpha * op(A)`. Here `op(A)` indicates that the matrix `A` can
|
||||
//! either be used as-is or transposed. The notation `op(A)` is represented in code by the
|
||||
//! [`Op`] enum.
|
||||
//!
|
||||
//! # Available `std::ops` implementations
|
||||
//!
|
||||
//! ## Binary operators
|
||||
//!
|
||||
//! The below table summarizes the currently supported binary operators between matrices.
|
||||
//! In general, binary operators between sparse matrices are only supported if both matrices
|
||||
//! are stored in the same format. All supported binary operators are implemented for
|
||||
//! all four combinations of values and references.
|
||||
//!
|
||||
//! <table>
|
||||
//! <tr>
|
||||
//! <th>LHS (down) \ RHS (right)</th>
|
||||
//! <th>COO</th>
|
||||
//! <th>CSR</th>
|
||||
//! <th>CSC</th>
|
||||
//! <th>Dense</th>
|
||||
//! </tr>
|
||||
//! <tr>
|
||||
//! <th>COO</th>
|
||||
//! <td></td>
|
||||
//! <td></td>
|
||||
//! <td></td>
|
||||
//! <td></td>
|
||||
//! </tr>
|
||||
//! <tr>
|
||||
//! <th>CSR</th>
|
||||
//! <td></td>
|
||||
//! <td>+ - *</td>
|
||||
//! <td></td>
|
||||
//! <td>*</td>
|
||||
//! </tr>
|
||||
//! <tr>
|
||||
//! <th>CSC</th>
|
||||
//! <td></td>
|
||||
//! <td></td>
|
||||
//! <td>+ - *</td>
|
||||
//! <td>*</td>
|
||||
//! </tr>
|
||||
//! <tr>
|
||||
//! <th>Dense</th>
|
||||
//! <td></td>
|
||||
//! <td></td>
|
||||
//! <td></td>
|
||||
//! <td>+ - *</td>
|
||||
//! </tr>
|
||||
//! </table>
|
||||
//!
|
||||
//! As can be seen from the table, only `CSR * Dense` and `CSC * Dense` are supported.
|
||||
//! The other way around, i.e. `Dense * CSR` and `Dense * CSC` are not implemented.
|
||||
//!
|
||||
//! Additionally, [CsrMatrix](`crate::csr::CsrMatrix`) and [CooMatrix](`crate::coo::CooMatrix`)
|
||||
//! support multiplication with scalars, in addition to division by a scalar.
|
||||
//! Note that only `Matrix * Scalar` works in a generic context, although `Scalar * Matrix`
|
||||
//! has been implemented for many of the built-in arithmetic types. This is due to a fundamental
|
||||
//! restriction of the Rust type system. Therefore, in generic code you will need to always place
|
||||
//! the matrix on the left-hand side of the multiplication.
|
||||
//!
|
||||
//! ## Unary operators
|
||||
//!
|
||||
//! The following table lists currently supported unary operators.
|
||||
//!
|
||||
//! | Format | AddAssign\<Matrix\> | MulAssign\<Matrix\> | MulAssign\<Scalar\> | Neg |
|
||||
//! | -------- | ----------------- | ----------------- | ------------------- | ------ |
|
||||
//! | COO | | | | |
|
||||
//! | CSR | | | x | x |
|
||||
//! | CSC | | | x | x |
|
||||
//! |
|
||||
//! # Example usage
|
||||
//!
|
||||
//! For example, consider the case where you want to compute the expression
|
||||
//! `C <- 3.0 * C + 2.0 * A^T * B`, where `A`, `B`, `C` are matrices and `A^T` is the transpose
|
||||
//! of `A`. The simplest way to write this is:
|
||||
//!
|
||||
//! ```rust
|
||||
//! # use nalgebra_sparse::csr::CsrMatrix;
|
||||
//! # let a = CsrMatrix::identity(10); let b = CsrMatrix::identity(10);
|
||||
//! # let mut c = CsrMatrix::identity(10);
|
||||
//! c = 3.0 * c + 2.0 * a.transpose() * b;
|
||||
//! ```
|
||||
//! This is simple and straightforward to read, and therefore the recommended way to implement
|
||||
//! it. However, if you have determined that this is a performance bottleneck of your application,
|
||||
//! it may be possible to speed things up. First, let's see what's going on here. The `std`
|
||||
//! operations are evaluated eagerly. This means that the following steps take place:
|
||||
//!
|
||||
//! 1. Evaluate `let c_temp = 3.0 * c`. This requires scaling all values of the matrix.
|
||||
//! 2. Evaluate `let a_t = a.transpose()` into a new temporary matrix.
|
||||
//! 3. Evaluate `let a_t_b = a_t * b` into a new temporary matrix.
|
||||
//! 4. Evaluate `let a_t_b_scaled = 2.0 * a_t_b`. This requires scaling all values of the matrix.
|
||||
//! 5. Evaluate `c = c_temp + a_t_b_scaled`.
|
||||
//!
|
||||
//! An alternative way to implement this expression (here using CSR matrices) is:
|
||||
//!
|
||||
//! ```rust
|
||||
//! # use nalgebra_sparse::csr::CsrMatrix;
|
||||
//! # let a = CsrMatrix::identity(10); let b = CsrMatrix::identity(10);
|
||||
//! # let mut c = CsrMatrix::identity(10);
|
||||
//! use nalgebra_sparse::ops::{Op, serial::spmm_csr_prealloc};
|
||||
//!
|
||||
//! // Evaluate the expression `c <- 3.0 * c + 2.0 * a^T * b
|
||||
//! spmm_csr_prealloc(3.0, &mut c, 2.0, Op::Transpose(&a), Op::NoOp(&b))
|
||||
//! .expect("We assume that the pattern of C is able to accommodate the result.");
|
||||
//! ```
|
||||
//! Compared to the simpler example, this snippet is harder to read, but it calls a single
|
||||
//! computational kernel that avoids many of the intermediate steps listed out before. Therefore
|
||||
//! directly calling kernels may sometimes lead to better performance. However, this should
|
||||
//! always be verified by performance profiling!
|
||||
|
||||
mod impl_std_ops;
|
||||
pub mod serial;
|
||||
|
||||
/// Determines whether a matrix should be transposed in a given operation.
|
||||
///
|
||||
/// See the [module-level documentation](crate::ops) for the purpose of this enum.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum Op<T> {
|
||||
/// Indicates that the matrix should be used as-is.
|
||||
NoOp(T),
|
||||
/// Indicates that the matrix should be transposed.
|
||||
Transpose(T),
|
||||
}
|
||||
|
||||
impl<T> Op<T> {
|
||||
/// Returns a reference to the inner value that the operation applies to.
|
||||
pub fn inner_ref(&self) -> &T {
|
||||
self.as_ref().into_inner()
|
||||
}
|
||||
|
||||
/// Returns an `Op` applied to a reference of the inner value.
|
||||
pub fn as_ref(&self) -> Op<&T> {
|
||||
match self {
|
||||
Op::NoOp(obj) => Op::NoOp(&obj),
|
||||
Op::Transpose(obj) => Op::Transpose(&obj),
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the underlying data type.
|
||||
pub fn convert<U>(self) -> Op<U>
|
||||
where
|
||||
T: Into<U>,
|
||||
{
|
||||
self.map_same_op(T::into)
|
||||
}
|
||||
|
||||
/// Transforms the inner value with the provided function, but preserves the operation.
|
||||
pub fn map_same_op<U, F: FnOnce(T) -> U>(self, f: F) -> Op<U> {
|
||||
match self {
|
||||
Op::NoOp(obj) => Op::NoOp(f(obj)),
|
||||
Op::Transpose(obj) => Op::Transpose(f(obj)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumes the `Op` and returns the inner value.
|
||||
pub fn into_inner(self) -> T {
|
||||
match self {
|
||||
Op::NoOp(obj) | Op::Transpose(obj) => obj,
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies the transpose operation.
|
||||
///
|
||||
/// This operation follows the usual semantics of transposition. In particular, double
|
||||
/// transposition is equivalent to no transposition.
|
||||
pub fn transposed(self) -> Self {
|
||||
match self {
|
||||
Op::NoOp(obj) => Op::Transpose(obj),
|
||||
Op::Transpose(obj) => Op::NoOp(obj),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for Op<T> {
|
||||
fn from(obj: T) -> Self {
|
||||
Self::NoOp(obj)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,186 @@
|
|||
use crate::cs::CsMatrix;
|
||||
use crate::ops::serial::{OperationError, OperationErrorKind};
|
||||
use crate::ops::Op;
|
||||
use crate::SparseEntryMut;
|
||||
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
|
||||
use num_traits::{One, Zero};
|
||||
|
||||
fn spmm_cs_unexpected_entry() -> OperationError {
|
||||
OperationError::from_kind_and_message(
|
||||
OperationErrorKind::InvalidPattern,
|
||||
String::from("Found unexpected entry that is not present in `c`."),
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper functionality for implementing CSR/CSC SPMM.
|
||||
///
|
||||
/// Since CSR/CSC matrices are basically transpositions of each other, which lets us use the same
|
||||
/// algorithm for the SPMM implementation. The implementation here is written in a CSR-centric
|
||||
/// manner. This means that when using it for CSC, the order of the matrices needs to be
|
||||
/// reversed (since transpose(AB) = transpose(B) * transpose(A) and CSC(A) = transpose(CSR(A)).
|
||||
///
|
||||
/// We assume here that the matrices have already been verified to be dimensionally compatible.
|
||||
pub fn spmm_cs_prealloc<T>(
|
||||
beta: T,
|
||||
c: &mut CsMatrix<T>,
|
||||
alpha: T,
|
||||
a: &CsMatrix<T>,
|
||||
b: &CsMatrix<T>,
|
||||
) -> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
for i in 0..c.pattern().major_dim() {
|
||||
let a_lane_i = a.get_lane(i).unwrap();
|
||||
let mut c_lane_i = c.get_lane_mut(i).unwrap();
|
||||
for c_ij in c_lane_i.values_mut() {
|
||||
*c_ij = beta.inlined_clone() * c_ij.inlined_clone();
|
||||
}
|
||||
|
||||
for (&k, a_ik) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
|
||||
let b_lane_k = b.get_lane(k).unwrap();
|
||||
let (mut c_lane_i_cols, mut c_lane_i_values) = c_lane_i.indices_and_values_mut();
|
||||
let alpha_aik = alpha.inlined_clone() * a_ik.inlined_clone();
|
||||
for (j, b_kj) in b_lane_k.minor_indices().iter().zip(b_lane_k.values()) {
|
||||
// Determine the location in C to append the value
|
||||
let (c_local_idx, _) = c_lane_i_cols
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, c_col)| *c_col == j)
|
||||
.ok_or_else(spmm_cs_unexpected_entry)?;
|
||||
|
||||
c_lane_i_values[c_local_idx] += alpha_aik.inlined_clone() * b_kj.inlined_clone();
|
||||
c_lane_i_cols = &c_lane_i_cols[c_local_idx..];
|
||||
c_lane_i_values = &mut c_lane_i_values[c_local_idx..];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn spadd_cs_unexpected_entry() -> OperationError {
|
||||
OperationError::from_kind_and_message(
|
||||
OperationErrorKind::InvalidPattern,
|
||||
String::from("Found entry in `op(a)` that is not present in `c`."),
|
||||
)
|
||||
}
|
||||
|
||||
/// Helper functionality for implementing CSR/CSC SPADD.
|
||||
pub fn spadd_cs_prealloc<T>(
|
||||
beta: T,
|
||||
c: &mut CsMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsMatrix<T>>,
|
||||
) -> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
match a {
|
||||
Op::NoOp(a) => {
|
||||
for (mut c_lane_i, a_lane_i) in c.lane_iter_mut().zip(a.lane_iter()) {
|
||||
if beta != T::one() {
|
||||
for c_ij in c_lane_i.values_mut() {
|
||||
*c_ij *= beta.inlined_clone();
|
||||
}
|
||||
}
|
||||
|
||||
let (mut c_minors, mut c_vals) = c_lane_i.indices_and_values_mut();
|
||||
let (a_minors, a_vals) = (a_lane_i.minor_indices(), a_lane_i.values());
|
||||
|
||||
for (a_col, a_val) in a_minors.iter().zip(a_vals) {
|
||||
// TODO: Use exponential search instead of linear search.
|
||||
// If C has substantially more entries in the row than A, then a line search
|
||||
// will needlessly visit many entries in C.
|
||||
let (c_idx, _) = c_minors
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, c_col)| *c_col == a_col)
|
||||
.ok_or_else(spadd_cs_unexpected_entry)?;
|
||||
c_vals[c_idx] += alpha.inlined_clone() * a_val.inlined_clone();
|
||||
c_minors = &c_minors[c_idx..];
|
||||
c_vals = &mut c_vals[c_idx..];
|
||||
}
|
||||
}
|
||||
}
|
||||
Op::Transpose(a) => {
|
||||
if beta != T::one() {
|
||||
for c_ij in c.values_mut() {
|
||||
*c_ij *= beta.inlined_clone();
|
||||
}
|
||||
}
|
||||
|
||||
for (i, a_lane_i) in a.lane_iter().enumerate() {
|
||||
for (&j, a_val) in a_lane_i.minor_indices().iter().zip(a_lane_i.values()) {
|
||||
let a_val = a_val.inlined_clone();
|
||||
let alpha = alpha.inlined_clone();
|
||||
match c.get_entry_mut(j, i).unwrap() {
|
||||
SparseEntryMut::NonZero(c_ji) => *c_ji += alpha * a_val,
|
||||
SparseEntryMut::Zero => return Err(spadd_cs_unexpected_entry()),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Helper functionality for implementing CSR/CSC SPMM.
|
||||
///
|
||||
/// The implementation essentially assumes that `a` is a CSR matrix. To use it with CSC matrices,
|
||||
/// the transposed operation must be specified for the CSC matrix.
|
||||
pub fn spmm_cs_dense<T>(
|
||||
beta: T,
|
||||
mut c: DMatrixSliceMut<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsMatrix<T>>,
|
||||
b: Op<DMatrixSlice<T>>,
|
||||
) where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
match a {
|
||||
Op::NoOp(a) => {
|
||||
for j in 0..c.ncols() {
|
||||
let mut c_col_j = c.column_mut(j);
|
||||
for (c_ij, a_row_i) in c_col_j.iter_mut().zip(a.lane_iter()) {
|
||||
let mut dot_ij = T::zero();
|
||||
for (&k, a_ik) in a_row_i.minor_indices().iter().zip(a_row_i.values()) {
|
||||
let b_contrib = match b {
|
||||
Op::NoOp(ref b) => b.index((k, j)),
|
||||
Op::Transpose(ref b) => b.index((j, k)),
|
||||
};
|
||||
dot_ij += a_ik.inlined_clone() * b_contrib.inlined_clone();
|
||||
}
|
||||
*c_ij = beta.inlined_clone() * c_ij.inlined_clone()
|
||||
+ alpha.inlined_clone() * dot_ij;
|
||||
}
|
||||
}
|
||||
}
|
||||
Op::Transpose(a) => {
|
||||
// In this case, we have to pre-multiply C by beta
|
||||
c *= beta;
|
||||
|
||||
for k in 0..a.pattern().major_dim() {
|
||||
let a_row_k = a.get_lane(k).unwrap();
|
||||
for (&i, a_ki) in a_row_k.minor_indices().iter().zip(a_row_k.values()) {
|
||||
let gamma_ki = alpha.inlined_clone() * a_ki.inlined_clone();
|
||||
let mut c_row_i = c.row_mut(i);
|
||||
match b {
|
||||
Op::NoOp(ref b) => {
|
||||
let b_row_k = b.row(k);
|
||||
for (c_ij, b_kj) in c_row_i.iter_mut().zip(b_row_k.iter()) {
|
||||
*c_ij += gamma_ki.inlined_clone() * b_kj.inlined_clone();
|
||||
}
|
||||
}
|
||||
Op::Transpose(ref b) => {
|
||||
let b_col_k = b.column(k);
|
||||
for (c_ij, b_jk) in c_row_i.iter_mut().zip(b_col_k.iter()) {
|
||||
*c_ij += gamma_ki.inlined_clone() * b_jk.inlined_clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,255 @@
|
|||
use crate::csc::CscMatrix;
|
||||
use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
|
||||
use crate::ops::serial::{OperationError, OperationErrorKind};
|
||||
use crate::ops::Op;
|
||||
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, RealField, Scalar};
|
||||
use num_traits::{One, Zero};
|
||||
|
||||
use std::borrow::Cow;
|
||||
|
||||
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||
pub fn spmm_csc_dense<'a, T>(
|
||||
beta: T,
|
||||
c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||
alpha: T,
|
||||
a: Op<&CscMatrix<T>>,
|
||||
b: Op<impl Into<DMatrixSlice<'a, T>>>,
|
||||
) where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
let b = b.convert();
|
||||
spmm_csc_dense_(beta, c.into(), alpha, a, b)
|
||||
}
|
||||
|
||||
fn spmm_csc_dense_<T>(
|
||||
beta: T,
|
||||
c: DMatrixSliceMut<T>,
|
||||
alpha: T,
|
||||
a: Op<&CscMatrix<T>>,
|
||||
b: Op<DMatrixSlice<T>>,
|
||||
) where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
assert_compatible_spmm_dims!(c, a, b);
|
||||
// Need to interpret matrix as transposed since the spmm_cs_dense function assumes CSR layout
|
||||
let a = a.transposed().map_same_op(|a| &a.cs);
|
||||
spmm_cs_dense(beta, c, alpha, a, b)
|
||||
}
|
||||
|
||||
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
|
||||
///
|
||||
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
|
||||
/// returned.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||
pub fn spadd_csc_prealloc<T>(
|
||||
beta: T,
|
||||
c: &mut CscMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CscMatrix<T>>,
|
||||
) -> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
assert_compatible_spadd_dims!(c, a);
|
||||
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
||||
}
|
||||
|
||||
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// If the sparsity pattern of `C` is not able to store the result of the operation,
|
||||
/// an error is returned.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||
pub fn spmm_csc_prealloc<T>(
|
||||
beta: T,
|
||||
c: &mut CscMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CscMatrix<T>>,
|
||||
b: Op<&CscMatrix<T>>,
|
||||
) -> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
assert_compatible_spmm_dims!(c, a, b);
|
||||
|
||||
use Op::{NoOp, Transpose};
|
||||
|
||||
match (&a, &b) {
|
||||
(NoOp(ref a), NoOp(ref b)) => {
|
||||
// Note: We have to reverse the order for CSC matrices
|
||||
spmm_cs_prealloc(beta, &mut c.cs, alpha, &b.cs, &a.cs)
|
||||
}
|
||||
_ => {
|
||||
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||
// and calling the operation again without transposition
|
||||
let a_ref: &CscMatrix<T> = a.inner_ref();
|
||||
let b_ref: &CscMatrix<T> = b.inner_ref();
|
||||
let (a, b) = {
|
||||
use Cow::*;
|
||||
match (&a, &b) {
|
||||
(NoOp(_), NoOp(_)) => unreachable!(),
|
||||
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
||||
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
||||
(Transpose(ref a), Transpose(ref b)) => {
|
||||
(Owned(a.transpose()), Owned(b.transpose()))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
spmm_csc_prealloc(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Solve the lower triangular system `op(L) X = B`.
|
||||
///
|
||||
/// Only the lower triangular part of L is read, and the result is stored in B.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// An error is returned if the system can not be solved due to the matrix being singular.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `L` is not square, or if `L` and `B` are not dimensionally compatible.
|
||||
pub fn spsolve_csc_lower_triangular<'a, T: RealField>(
|
||||
l: Op<&CscMatrix<T>>,
|
||||
b: impl Into<DMatrixSliceMut<'a, T>>,
|
||||
) -> Result<(), OperationError> {
|
||||
let b = b.into();
|
||||
let l_matrix = l.into_inner();
|
||||
assert_eq!(
|
||||
l_matrix.nrows(),
|
||||
l_matrix.ncols(),
|
||||
"Matrix must be square for triangular solve."
|
||||
);
|
||||
assert_eq!(
|
||||
l_matrix.nrows(),
|
||||
b.nrows(),
|
||||
"Dimension mismatch in sparse lower triangular solver."
|
||||
);
|
||||
match l {
|
||||
Op::NoOp(a) => spsolve_csc_lower_triangular_no_transpose(a, b),
|
||||
Op::Transpose(a) => spsolve_csc_lower_triangular_transpose(a, b),
|
||||
}
|
||||
}
|
||||
|
||||
fn spsolve_csc_lower_triangular_no_transpose<T: RealField>(
|
||||
l: &CscMatrix<T>,
|
||||
b: DMatrixSliceMut<T>,
|
||||
) -> Result<(), OperationError> {
|
||||
let mut x = b;
|
||||
|
||||
// Solve column-by-column
|
||||
for j in 0..x.ncols() {
|
||||
let mut x_col_j = x.column_mut(j);
|
||||
|
||||
for k in 0..l.ncols() {
|
||||
let l_col_k = l.col(k);
|
||||
|
||||
// Skip entries above the diagonal
|
||||
// TODO: Can use exponential search here to quickly skip entries
|
||||
// (we'd like to avoid using binary search as it's very cache unfriendly
|
||||
// and the matrix might actually *be* lower triangular, which would induce
|
||||
// a severe penalty)
|
||||
let diag_csc_index = l_col_k.row_indices().iter().position(|&i| i == k);
|
||||
if let Some(diag_csc_index) = diag_csc_index {
|
||||
let l_kk = l_col_k.values()[diag_csc_index];
|
||||
|
||||
if l_kk != T::zero() {
|
||||
// Update entry associated with diagonal
|
||||
x_col_j[k] /= l_kk;
|
||||
// Copy value after updating (so we don't run into the borrow checker)
|
||||
let x_kj = x_col_j[k];
|
||||
|
||||
let row_indices = &l_col_k.row_indices()[(diag_csc_index + 1)..];
|
||||
let l_values = &l_col_k.values()[(diag_csc_index + 1)..];
|
||||
|
||||
// Note: The remaining entries are below the diagonal
|
||||
for (&i, l_ik) in row_indices.iter().zip(l_values) {
|
||||
let x_ij = &mut x_col_j[i];
|
||||
*x_ij -= l_ik.inlined_clone() * x_kj;
|
||||
}
|
||||
|
||||
x_col_j[k] = x_kj;
|
||||
} else {
|
||||
return spsolve_encountered_zero_diagonal();
|
||||
}
|
||||
} else {
|
||||
return spsolve_encountered_zero_diagonal();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn spsolve_encountered_zero_diagonal() -> Result<(), OperationError> {
|
||||
let message = "Matrix contains at least one diagonal entry that is zero.";
|
||||
Err(OperationError::from_kind_and_message(
|
||||
OperationErrorKind::Singular,
|
||||
String::from(message),
|
||||
))
|
||||
}
|
||||
|
||||
fn spsolve_csc_lower_triangular_transpose<T: RealField>(
|
||||
l: &CscMatrix<T>,
|
||||
b: DMatrixSliceMut<T>,
|
||||
) -> Result<(), OperationError> {
|
||||
let mut x = b;
|
||||
|
||||
// Solve column-by-column
|
||||
for j in 0..x.ncols() {
|
||||
let mut x_col_j = x.column_mut(j);
|
||||
|
||||
// Due to the transposition, we're essentially solving an upper triangular system,
|
||||
// and the columns in our matrix become rows
|
||||
|
||||
for i in (0..l.ncols()).rev() {
|
||||
let l_col_i = l.col(i);
|
||||
|
||||
// Skip entries above the diagonal
|
||||
// TODO: Can use exponential search here to quickly skip entries
|
||||
let diag_csc_index = l_col_i.row_indices().iter().position(|&k| i == k);
|
||||
if let Some(diag_csc_index) = diag_csc_index {
|
||||
let l_ii = l_col_i.values()[diag_csc_index];
|
||||
|
||||
if l_ii != T::zero() {
|
||||
// // Update entry associated with diagonal
|
||||
// x_col_j[k] /= a_kk;
|
||||
|
||||
// Copy value after updating (so we don't run into the borrow checker)
|
||||
let mut x_ii = x_col_j[i];
|
||||
|
||||
let row_indices = &l_col_i.row_indices()[(diag_csc_index + 1)..];
|
||||
let a_values = &l_col_i.values()[(diag_csc_index + 1)..];
|
||||
|
||||
// Note: The remaining entries are below the diagonal
|
||||
for (&k, &l_ki) in row_indices.iter().zip(a_values) {
|
||||
let x_kj = x_col_j[k];
|
||||
x_ii -= l_ki * x_kj;
|
||||
}
|
||||
|
||||
x_col_j[i] = x_ii / l_ii;
|
||||
} else {
|
||||
return spsolve_encountered_zero_diagonal();
|
||||
}
|
||||
} else {
|
||||
return spsolve_encountered_zero_diagonal();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
use crate::csr::CsrMatrix;
|
||||
use crate::ops::serial::cs::{spadd_cs_prealloc, spmm_cs_dense, spmm_cs_prealloc};
|
||||
use crate::ops::serial::OperationError;
|
||||
use crate::ops::Op;
|
||||
use nalgebra::{ClosedAdd, ClosedMul, DMatrixSlice, DMatrixSliceMut, Scalar};
|
||||
use num_traits::{One, Zero};
|
||||
use std::borrow::Cow;
|
||||
|
||||
/// Sparse-dense matrix-matrix multiplication `C <- beta * C + alpha * op(A) * op(B)`.
|
||||
pub fn spmm_csr_dense<'a, T>(
|
||||
beta: T,
|
||||
c: impl Into<DMatrixSliceMut<'a, T>>,
|
||||
alpha: T,
|
||||
a: Op<&CsrMatrix<T>>,
|
||||
b: Op<impl Into<DMatrixSlice<'a, T>>>,
|
||||
) where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
let b = b.convert();
|
||||
spmm_csr_dense_(beta, c.into(), alpha, a, b)
|
||||
}
|
||||
|
||||
fn spmm_csr_dense_<T>(
|
||||
beta: T,
|
||||
c: DMatrixSliceMut<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsrMatrix<T>>,
|
||||
b: Op<DMatrixSlice<T>>,
|
||||
) where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
assert_compatible_spmm_dims!(c, a, b);
|
||||
spmm_cs_dense(beta, c, alpha, a.map_same_op(|a| &a.cs), b)
|
||||
}
|
||||
|
||||
/// Sparse matrix addition `C <- beta * C + alpha * op(A)`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// If the pattern of `c` does not accommodate all the non-zero entries in `a`, an error is
|
||||
/// returned.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||
pub fn spadd_csr_prealloc<T>(
|
||||
beta: T,
|
||||
c: &mut CsrMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsrMatrix<T>>,
|
||||
) -> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
assert_compatible_spadd_dims!(c, a);
|
||||
spadd_cs_prealloc(beta, &mut c.cs, alpha, a.map_same_op(|a| &a.cs))
|
||||
}
|
||||
|
||||
/// Sparse-sparse matrix multiplication, `C <- beta * C + alpha * op(A) * op(B)`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// If the pattern of `C` is not able to hold the result of the operation, an error is returned.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the dimensions of the matrices involved are not compatible with the expression.
|
||||
pub fn spmm_csr_prealloc<T>(
|
||||
beta: T,
|
||||
c: &mut CsrMatrix<T>,
|
||||
alpha: T,
|
||||
a: Op<&CsrMatrix<T>>,
|
||||
b: Op<&CsrMatrix<T>>,
|
||||
) -> Result<(), OperationError>
|
||||
where
|
||||
T: Scalar + ClosedAdd + ClosedMul + Zero + One,
|
||||
{
|
||||
assert_compatible_spmm_dims!(c, a, b);
|
||||
|
||||
use Op::{NoOp, Transpose};
|
||||
|
||||
match (&a, &b) {
|
||||
(NoOp(ref a), NoOp(ref b)) => spmm_cs_prealloc(beta, &mut c.cs, alpha, &a.cs, &b.cs),
|
||||
_ => {
|
||||
// Currently we handle transposition by explicitly precomputing transposed matrices
|
||||
// and calling the operation again without transposition
|
||||
// TODO: At least use workspaces to allow control of allocations. Maybe
|
||||
// consider implementing certain patterns (like A^T * B) explicitly
|
||||
let a_ref: &CsrMatrix<T> = a.inner_ref();
|
||||
let b_ref: &CsrMatrix<T> = b.inner_ref();
|
||||
let (a, b) = {
|
||||
use Cow::*;
|
||||
match (&a, &b) {
|
||||
(NoOp(_), NoOp(_)) => unreachable!(),
|
||||
(Transpose(ref a), NoOp(_)) => (Owned(a.transpose()), Borrowed(b_ref)),
|
||||
(NoOp(_), Transpose(ref b)) => (Borrowed(a_ref), Owned(b.transpose())),
|
||||
(Transpose(ref a), Transpose(ref b)) => {
|
||||
(Owned(a.transpose()), Owned(b.transpose()))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
spmm_csr_prealloc(beta, c, alpha, NoOp(a.as_ref()), NoOp(b.as_ref()))
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,124 @@
|
|||
//! Serial sparse matrix arithmetic routines.
|
||||
//!
|
||||
//! All routines are single-threaded.
|
||||
//!
|
||||
//! Some operations have the `prealloc` suffix. This means that they expect that the sparsity
|
||||
//! pattern of the output matrix has already been pre-allocated: that is, the pattern of the result
|
||||
//! of the operation fits entirely in the output pattern. In the future, there will also be
|
||||
//! some operations which will be able to dynamically adapt the output pattern to fit the
|
||||
//! result, but these have yet to be implemented.
|
||||
|
||||
#[macro_use]
|
||||
macro_rules! assert_compatible_spmm_dims {
|
||||
($c:expr, $a:expr, $b:expr) => {{
|
||||
use crate::ops::Op::{NoOp, Transpose};
|
||||
match (&$a, &$b) {
|
||||
(NoOp(ref a), NoOp(ref b)) => {
|
||||
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
||||
assert_eq!(a.ncols(), b.nrows(), "A.ncols() != B.nrows()");
|
||||
}
|
||||
(Transpose(ref a), NoOp(ref b)) => {
|
||||
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||
assert_eq!($c.ncols(), b.ncols(), "C.ncols() != B.ncols()");
|
||||
assert_eq!(a.nrows(), b.nrows(), "A.nrows() != B.nrows()");
|
||||
}
|
||||
(NoOp(ref a), Transpose(ref b)) => {
|
||||
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
||||
assert_eq!(a.ncols(), b.ncols(), "A.ncols() != B.ncols()");
|
||||
}
|
||||
(Transpose(ref a), Transpose(ref b)) => {
|
||||
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||
assert_eq!($c.ncols(), b.nrows(), "C.ncols() != B.nrows()");
|
||||
assert_eq!(a.nrows(), b.ncols(), "A.nrows() != B.ncols()");
|
||||
}
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
#[macro_use]
|
||||
macro_rules! assert_compatible_spadd_dims {
|
||||
($c:expr, $a:expr) => {
|
||||
use crate::ops::Op;
|
||||
match $a {
|
||||
Op::NoOp(a) => {
|
||||
assert_eq!($c.nrows(), a.nrows(), "C.nrows() != A.nrows()");
|
||||
assert_eq!($c.ncols(), a.ncols(), "C.ncols() != A.ncols()");
|
||||
}
|
||||
Op::Transpose(a) => {
|
||||
assert_eq!($c.nrows(), a.ncols(), "C.nrows() != A.ncols()");
|
||||
assert_eq!($c.ncols(), a.nrows(), "C.ncols() != A.nrows()");
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
mod cs;
|
||||
mod csc;
|
||||
mod csr;
|
||||
mod pattern;
|
||||
|
||||
pub use csc::*;
|
||||
pub use csr::*;
|
||||
pub use pattern::*;
|
||||
use std::fmt;
|
||||
use std::fmt::Formatter;
|
||||
|
||||
/// A description of the error that occurred during an arithmetic operation.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OperationError {
|
||||
error_kind: OperationErrorKind,
|
||||
message: String,
|
||||
}
|
||||
|
||||
/// The different kinds of operation errors that may occur.
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum OperationErrorKind {
|
||||
/// Indicates that one or more sparsity patterns involved in the operation violate the
|
||||
/// expectations of the routine.
|
||||
///
|
||||
/// For example, this could indicate that the sparsity pattern of the output is not able to
|
||||
/// contain the result of the operation.
|
||||
InvalidPattern,
|
||||
|
||||
/// Indicates that a matrix is singular when it is expected to be invertible.
|
||||
Singular,
|
||||
}
|
||||
|
||||
impl OperationError {
|
||||
fn from_kind_and_message(error_type: OperationErrorKind, message: String) -> Self {
|
||||
Self {
|
||||
error_kind: error_type,
|
||||
message,
|
||||
}
|
||||
}
|
||||
|
||||
/// The operation error kind.
|
||||
pub fn kind(&self) -> &OperationErrorKind {
|
||||
&self.error_kind
|
||||
}
|
||||
|
||||
/// The underlying error message.
|
||||
pub fn message(&self) -> &str {
|
||||
self.message.as_str()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for OperationError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "Sparse matrix operation error: ")?;
|
||||
match self.kind() {
|
||||
OperationErrorKind::InvalidPattern => {
|
||||
write!(f, "InvalidPattern")?;
|
||||
}
|
||||
OperationErrorKind::Singular => {
|
||||
write!(f, "Singular")?;
|
||||
}
|
||||
}
|
||||
write!(f, " Message: {}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for OperationError {}
|
|
@ -0,0 +1,152 @@
|
|||
use crate::pattern::SparsityPattern;
|
||||
|
||||
use std::iter;
|
||||
|
||||
/// Sparse matrix addition pattern construction, `C <- A + B`.
|
||||
///
|
||||
/// Builds the pattern for `C`, which is able to hold the result of the sum `A + B`.
|
||||
/// The patterns are assumed to have the same major and minor dimensions. In other words,
|
||||
/// both patterns `A` and `B` must both stem from the same kind of compressed matrix:
|
||||
/// CSR or CSC.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the patterns do not have the same major and minor dimensions.
|
||||
pub fn spadd_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||
assert_eq!(
|
||||
a.major_dim(),
|
||||
b.major_dim(),
|
||||
"Patterns must have identical major dimensions."
|
||||
);
|
||||
assert_eq!(
|
||||
a.minor_dim(),
|
||||
b.minor_dim(),
|
||||
"Patterns must have identical minor dimensions."
|
||||
);
|
||||
|
||||
let mut offsets = Vec::new();
|
||||
let mut indices = Vec::new();
|
||||
offsets.reserve(a.major_dim() + 1);
|
||||
indices.clear();
|
||||
|
||||
offsets.push(0);
|
||||
|
||||
for lane_idx in 0..a.major_dim() {
|
||||
let lane_a = a.lane(lane_idx);
|
||||
let lane_b = b.lane(lane_idx);
|
||||
indices.extend(iterate_union(lane_a, lane_b));
|
||||
offsets.push(indices.len());
|
||||
}
|
||||
|
||||
// TODO: Consider circumventing format checks? (requires unsafe, should benchmark first)
|
||||
SparsityPattern::try_from_offsets_and_indices(a.major_dim(), a.minor_dim(), offsets, indices)
|
||||
.expect("Internal error: Pattern must be valid by definition")
|
||||
}
|
||||
|
||||
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||
///
|
||||
/// Assumes that the sparsity patterns both represent CSC matrices, and the result is also
|
||||
/// represented as the sparsity pattern of a CSC matrix.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the patterns, when interpreted as CSC patterns, are not compatible for
|
||||
/// matrix multiplication.
|
||||
pub fn spmm_csc_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||
// Let C = A * B in CSC format. We note that
|
||||
// C^T = B^T * A^T.
|
||||
// Since the interpretation of a CSC matrix in CSR format represents the transpose of the
|
||||
// matrix in CSR, we can compute C^T in *CSR format* by switching the order of a and b,
|
||||
// which lets us obtain C^T in CSR format. Re-interpreting this as CSC gives us C in CSC format
|
||||
spmm_csr_pattern(b, a)
|
||||
}
|
||||
|
||||
/// Sparse matrix multiplication pattern construction, `C <- A * B`.
|
||||
///
|
||||
/// Assumes that the sparsity patterns both represent CSR matrices, and the result is also
|
||||
/// represented as the sparsity pattern of a CSR matrix.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the patterns, when interpreted as CSR patterns, are not compatible for
|
||||
/// matrix multiplication.
|
||||
pub fn spmm_csr_pattern(a: &SparsityPattern, b: &SparsityPattern) -> SparsityPattern {
|
||||
assert_eq!(
|
||||
a.minor_dim(),
|
||||
b.major_dim(),
|
||||
"a and b must have compatible dimensions"
|
||||
);
|
||||
|
||||
let mut offsets = Vec::new();
|
||||
let mut indices = Vec::new();
|
||||
offsets.push(0);
|
||||
|
||||
// Keep a vector of whether we have visited a particular minor index when working
|
||||
// on a major lane
|
||||
// TODO: Consider using a bitvec or similar here to reduce pressure on memory
|
||||
// (would cut memory use to 1/8, which might help reduce cache misses)
|
||||
let mut visited = vec![false; b.minor_dim()];
|
||||
|
||||
for i in 0..a.major_dim() {
|
||||
let a_lane_i = a.lane(i);
|
||||
let c_lane_i_offset = *offsets.last().unwrap();
|
||||
for &k in a_lane_i {
|
||||
let b_lane_k = b.lane(k);
|
||||
|
||||
for &j in b_lane_k {
|
||||
let have_visited_j = &mut visited[j];
|
||||
if !*have_visited_j {
|
||||
indices.push(j);
|
||||
*have_visited_j = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let c_lane_i = &mut indices[c_lane_i_offset..];
|
||||
c_lane_i.sort_unstable();
|
||||
|
||||
// Reset visits so that visited[j] == false for all j for the next major lane
|
||||
for j in c_lane_i {
|
||||
visited[*j] = false;
|
||||
}
|
||||
|
||||
offsets.push(indices.len());
|
||||
}
|
||||
|
||||
SparsityPattern::try_from_offsets_and_indices(a.major_dim(), b.minor_dim(), offsets, indices)
|
||||
.expect("Internal error: Invalid pattern during matrix multiplication pattern construction")
|
||||
}
|
||||
|
||||
/// Iterate over the union of the two sets represented by sorted slices
|
||||
/// (with unique elements)
|
||||
fn iterate_union<'a>(
|
||||
mut sorted_a: &'a [usize],
|
||||
mut sorted_b: &'a [usize],
|
||||
) -> impl Iterator<Item = usize> + 'a {
|
||||
iter::from_fn(move || {
|
||||
if let (Some(a_item), Some(b_item)) = (sorted_a.first(), sorted_b.first()) {
|
||||
let item = if a_item < b_item {
|
||||
sorted_a = &sorted_a[1..];
|
||||
a_item
|
||||
} else if b_item < a_item {
|
||||
sorted_b = &sorted_b[1..];
|
||||
b_item
|
||||
} else {
|
||||
// Both lists contain the same element, advance both slices to avoid
|
||||
// duplicate entries in the result
|
||||
sorted_a = &sorted_a[1..];
|
||||
sorted_b = &sorted_b[1..];
|
||||
a_item
|
||||
};
|
||||
Some(*item)
|
||||
} else if let Some(a_item) = sorted_a.first() {
|
||||
sorted_a = &sorted_a[1..];
|
||||
Some(*a_item)
|
||||
} else if let Some(b_item) = sorted_b.first() {
|
||||
sorted_b = &sorted_b[1..];
|
||||
Some(*b_item)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,393 @@
|
|||
//! Sparsity patterns for CSR and CSC matrices.
|
||||
use crate::cs::transpose_cs;
|
||||
use crate::SparseFormatError;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
/// A representation of the sparsity pattern of a CSR or CSC matrix.
|
||||
///
|
||||
/// CSR and CSC matrices store matrices in a very similar fashion. In fact, in a certain sense,
|
||||
/// they are transposed. More precisely, when reinterpreting the three data arrays of a CSR
|
||||
/// matrix as a CSC matrix, we obtain the CSC representation of its transpose.
|
||||
///
|
||||
/// [`SparsityPattern`] is an abstraction built on this observation. Whereas CSR matrices
|
||||
/// store a matrix row-by-row, and a CSC matrix stores a matrix column-by-column, a
|
||||
/// `SparsityPattern` represents only the index data structure of a matrix *lane-by-lane*.
|
||||
/// Here, a *lane* is a generalization of rows and columns. We further define *major lanes*
|
||||
/// and *minor lanes*. The sparsity pattern of a CSR matrix is then obtained by interpreting
|
||||
/// major/minor as row/column. Conversely, we obtain the sparsity pattern of a CSC matrix by
|
||||
/// interpreting major/minor as column/row.
|
||||
///
|
||||
/// This allows us to use a common abstraction to talk about sparsity patterns of CSR and CSC
|
||||
/// matrices. This is convenient, because at the abstract level, the invariants of the formats
|
||||
/// are the same. Hence we may encode the invariants of the index data structure separately from
|
||||
/// the scalar values of the matrix. This is especially useful in applications where the
|
||||
/// sparsity pattern is built ahead of the matrix values, or the same sparsity pattern is re-used
|
||||
/// between different matrices. Finally, we can use `SparsityPattern` to encode adjacency
|
||||
/// information in graphs.
|
||||
///
|
||||
/// # Format
|
||||
///
|
||||
/// The format is exactly the same as for the index data structures of CSR and CSC matrices.
|
||||
/// This means that the sparsity pattern of an `m x n` sparse matrix with `nnz` non-zeros,
|
||||
/// where in this case `m x n` does *not* mean `rows x columns`, but rather `majors x minors`,
|
||||
/// is represented by the following two arrays:
|
||||
///
|
||||
/// - `major_offsets`, an array of integers with length `m + 1`.
|
||||
/// - `minor_indices`, an array of integers with length `nnz`.
|
||||
///
|
||||
/// The invariants and relationship between `major_offsets` and `minor_indices` remain the same
|
||||
/// as for `row_offsets` and `col_indices` in the [CSR](`crate::csr::CsrMatrix`) format
|
||||
/// specification.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
// TODO: Make SparsityPattern parametrized by index type
|
||||
// (need a solid abstraction for index types though)
|
||||
pub struct SparsityPattern {
|
||||
major_offsets: Vec<usize>,
|
||||
minor_indices: Vec<usize>,
|
||||
minor_dim: usize,
|
||||
}
|
||||
|
||||
impl SparsityPattern {
|
||||
/// Create a sparsity pattern of the given dimensions without explicitly stored entries.
|
||||
pub fn zeros(major_dim: usize, minor_dim: usize) -> Self {
|
||||
Self {
|
||||
major_offsets: vec![0; major_dim + 1],
|
||||
minor_indices: vec![],
|
||||
minor_dim,
|
||||
}
|
||||
}
|
||||
|
||||
/// The offsets for the major dimension.
|
||||
#[inline]
|
||||
pub fn major_offsets(&self) -> &[usize] {
|
||||
&self.major_offsets
|
||||
}
|
||||
|
||||
/// The indices for the minor dimension.
|
||||
#[inline]
|
||||
pub fn minor_indices(&self) -> &[usize] {
|
||||
&self.minor_indices
|
||||
}
|
||||
|
||||
/// The number of major lanes in the pattern.
|
||||
#[inline]
|
||||
pub fn major_dim(&self) -> usize {
|
||||
assert!(self.major_offsets.len() > 0);
|
||||
self.major_offsets.len() - 1
|
||||
}
|
||||
|
||||
/// The number of minor lanes in the pattern.
|
||||
#[inline]
|
||||
pub fn minor_dim(&self) -> usize {
|
||||
self.minor_dim
|
||||
}
|
||||
|
||||
/// The number of "non-zeros", i.e. explicitly stored entries in the pattern.
|
||||
#[inline]
|
||||
pub fn nnz(&self) -> usize {
|
||||
self.minor_indices.len()
|
||||
}
|
||||
|
||||
/// Get the lane at the given index.
|
||||
///
|
||||
/// Panics
|
||||
/// ------
|
||||
///
|
||||
/// Panics if `major_index` is out of bounds.
|
||||
#[inline]
|
||||
pub fn lane(&self, major_index: usize) -> &[usize] {
|
||||
self.get_lane(major_index).unwrap()
|
||||
}
|
||||
|
||||
/// Get the lane at the given index, or `None` if out of bounds.
|
||||
#[inline]
|
||||
pub fn get_lane(&self, major_index: usize) -> Option<&[usize]> {
|
||||
let offset_begin = *self.major_offsets().get(major_index)?;
|
||||
let offset_end = *self.major_offsets().get(major_index + 1)?;
|
||||
Some(&self.minor_indices()[offset_begin..offset_end])
|
||||
}
|
||||
|
||||
/// Try to construct a sparsity pattern from the given dimensions, major offsets
|
||||
/// and minor indices.
|
||||
///
|
||||
/// Returns an error if the data does not conform to the requirements.
|
||||
pub fn try_from_offsets_and_indices(
|
||||
major_dim: usize,
|
||||
minor_dim: usize,
|
||||
major_offsets: Vec<usize>,
|
||||
minor_indices: Vec<usize>,
|
||||
) -> Result<Self, SparsityPatternFormatError> {
|
||||
use SparsityPatternFormatError::*;
|
||||
|
||||
if major_offsets.len() != major_dim + 1 {
|
||||
return Err(InvalidOffsetArrayLength);
|
||||
}
|
||||
|
||||
// Check that the first and last offsets conform to the specification
|
||||
{
|
||||
let first_offset_ok = *major_offsets.first().unwrap() == 0;
|
||||
let last_offset_ok = *major_offsets.last().unwrap() == minor_indices.len();
|
||||
if !first_offset_ok || !last_offset_ok {
|
||||
return Err(InvalidOffsetFirstLast);
|
||||
}
|
||||
}
|
||||
|
||||
// Test that each lane has strictly monotonically increasing minor indices, i.e.
|
||||
// minor indices within a lane are sorted, unique. In addition, each minor index
|
||||
// must be in bounds with respect to the minor dimension.
|
||||
{
|
||||
for lane_idx in 0..major_dim {
|
||||
let range_start = major_offsets[lane_idx];
|
||||
let range_end = major_offsets[lane_idx + 1];
|
||||
|
||||
// Test that major offsets are monotonically increasing
|
||||
if range_start > range_end {
|
||||
return Err(NonmonotonicOffsets);
|
||||
}
|
||||
|
||||
let minor_indices = &minor_indices[range_start..range_end];
|
||||
|
||||
// We test for in-bounds, uniqueness and monotonicity at the same time
|
||||
// to ensure that we only visit each minor index once
|
||||
let mut iter = minor_indices.iter();
|
||||
let mut prev = None;
|
||||
|
||||
while let Some(next) = iter.next().copied() {
|
||||
if next >= minor_dim {
|
||||
return Err(MinorIndexOutOfBounds);
|
||||
}
|
||||
|
||||
if let Some(prev) = prev {
|
||||
if prev > next {
|
||||
return Err(NonmonotonicMinorIndices);
|
||||
} else if prev == next {
|
||||
return Err(DuplicateEntry);
|
||||
}
|
||||
}
|
||||
prev = Some(next);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
major_offsets,
|
||||
minor_indices,
|
||||
minor_dim,
|
||||
})
|
||||
}
|
||||
|
||||
/// An iterator over the explicitly stored "non-zero" entries (i, j).
|
||||
///
|
||||
/// The iteration happens in a lane-major fashion, meaning that the lane index i
|
||||
/// increases monotonically, and the minor index j increases monotonically within each
|
||||
/// lane i.
|
||||
///
|
||||
/// Examples
|
||||
/// --------
|
||||
///
|
||||
/// ```
|
||||
/// # use nalgebra_sparse::pattern::SparsityPattern;
|
||||
/// let offsets = vec![0, 2, 3, 4];
|
||||
/// let minor_indices = vec![0, 2, 1, 0];
|
||||
/// let pattern = SparsityPattern::try_from_offsets_and_indices(3, 4, offsets, minor_indices)
|
||||
/// .unwrap();
|
||||
///
|
||||
/// let entries: Vec<_> = pattern.entries().collect();
|
||||
/// assert_eq!(entries, vec![(0, 0), (0, 2), (1, 1), (2, 0)]);
|
||||
/// ```
|
||||
///
|
||||
pub fn entries(&self) -> SparsityPatternIter {
|
||||
SparsityPatternIter::from_pattern(self)
|
||||
}
|
||||
|
||||
/// Returns the raw offset and index data for the sparsity pattern.
|
||||
///
|
||||
/// Examples
|
||||
/// --------
|
||||
///
|
||||
/// ```
|
||||
/// # use nalgebra_sparse::pattern::SparsityPattern;
|
||||
/// let offsets = vec![0, 2, 3, 4];
|
||||
/// let minor_indices = vec![0, 2, 1, 0];
|
||||
/// let pattern = SparsityPattern::try_from_offsets_and_indices(
|
||||
/// 3,
|
||||
/// 4,
|
||||
/// offsets.clone(),
|
||||
/// minor_indices.clone())
|
||||
/// .unwrap();
|
||||
/// let (offsets2, minor_indices2) = pattern.disassemble();
|
||||
/// assert_eq!(offsets2, offsets);
|
||||
/// assert_eq!(minor_indices2, minor_indices);
|
||||
/// ```
|
||||
pub fn disassemble(self) -> (Vec<usize>, Vec<usize>) {
|
||||
(self.major_offsets, self.minor_indices)
|
||||
}
|
||||
|
||||
/// Computes the transpose of the sparsity pattern.
|
||||
///
|
||||
/// This is analogous to matrix transposition, i.e. an entry `(i, j)` becomes `(j, i)` in the
|
||||
/// new pattern.
|
||||
pub fn transpose(&self) -> Self {
|
||||
// By using unit () values, we can use the same routines as for CSR/CSC matrices
|
||||
let values = vec![(); self.nnz()];
|
||||
let (new_offsets, new_indices, _) = transpose_cs(
|
||||
self.major_dim(),
|
||||
self.minor_dim(),
|
||||
self.major_offsets(),
|
||||
self.minor_indices(),
|
||||
&values,
|
||||
);
|
||||
// TODO: Skip checks
|
||||
Self::try_from_offsets_and_indices(
|
||||
self.minor_dim(),
|
||||
self.major_dim(),
|
||||
new_offsets,
|
||||
new_indices,
|
||||
)
|
||||
.expect("Internal error: Transpose should never fail.")
|
||||
}
|
||||
}
|
||||
|
||||
/// Error type for `SparsityPattern` format errors.
|
||||
#[non_exhaustive]
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum SparsityPatternFormatError {
|
||||
/// Indicates an invalid number of offsets.
|
||||
///
|
||||
/// The number of offsets must be equal to (major_dim + 1).
|
||||
InvalidOffsetArrayLength,
|
||||
/// Indicates that the first or last entry in the offset array did not conform to
|
||||
/// specifications.
|
||||
///
|
||||
/// The first entry must be 0, and the last entry must be exactly one greater than the
|
||||
/// major dimension.
|
||||
InvalidOffsetFirstLast,
|
||||
/// Indicates that the major offsets are not monotonically increasing.
|
||||
NonmonotonicOffsets,
|
||||
/// One or more minor indices are out of bounds.
|
||||
MinorIndexOutOfBounds,
|
||||
/// One or more duplicate entries were detected.
|
||||
///
|
||||
/// Two entries are considered duplicates if they are part of the same major lane and have
|
||||
/// the same minor index.
|
||||
DuplicateEntry,
|
||||
/// Indicates that minor indices are not monotonically increasing within each lane.
|
||||
NonmonotonicMinorIndices,
|
||||
}
|
||||
|
||||
impl From<SparsityPatternFormatError> for SparseFormatError {
|
||||
fn from(err: SparsityPatternFormatError) -> Self {
|
||||
use crate::SparseFormatErrorKind;
|
||||
use crate::SparseFormatErrorKind::*;
|
||||
use SparsityPatternFormatError::DuplicateEntry as PatternDuplicateEntry;
|
||||
use SparsityPatternFormatError::*;
|
||||
match err {
|
||||
InvalidOffsetArrayLength
|
||||
| InvalidOffsetFirstLast
|
||||
| NonmonotonicOffsets
|
||||
| NonmonotonicMinorIndices => {
|
||||
SparseFormatError::from_kind_and_error(InvalidStructure, Box::from(err))
|
||||
}
|
||||
MinorIndexOutOfBounds => {
|
||||
SparseFormatError::from_kind_and_error(IndexOutOfBounds, Box::from(err))
|
||||
}
|
||||
PatternDuplicateEntry => SparseFormatError::from_kind_and_error(
|
||||
#[allow(unused_qualifications)]
|
||||
SparseFormatErrorKind::DuplicateEntry,
|
||||
Box::from(err),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SparsityPatternFormatError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SparsityPatternFormatError::InvalidOffsetArrayLength => {
|
||||
write!(f, "Length of offset array is not equal to (major_dim + 1).")
|
||||
}
|
||||
SparsityPatternFormatError::InvalidOffsetFirstLast => {
|
||||
write!(f, "First or last offset is incompatible with format.")
|
||||
}
|
||||
SparsityPatternFormatError::NonmonotonicOffsets => {
|
||||
write!(f, "Offsets are not monotonically increasing.")
|
||||
}
|
||||
SparsityPatternFormatError::MinorIndexOutOfBounds => {
|
||||
write!(f, "A minor index is out of bounds.")
|
||||
}
|
||||
SparsityPatternFormatError::DuplicateEntry => {
|
||||
write!(f, "Input data contains duplicate entries.")
|
||||
}
|
||||
SparsityPatternFormatError::NonmonotonicMinorIndices => {
|
||||
write!(
|
||||
f,
|
||||
"Minor indices are not monotonically increasing within each lane."
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for SparsityPatternFormatError {}
|
||||
|
||||
/// Iterator type for iterating over entries in a sparsity pattern.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SparsityPatternIter<'a> {
|
||||
// See implementation of Iterator::next for an explanation of how these members are used
|
||||
major_offsets: &'a [usize],
|
||||
minor_indices: &'a [usize],
|
||||
current_lane_idx: usize,
|
||||
remaining_minors_in_lane: &'a [usize],
|
||||
}
|
||||
|
||||
impl<'a> SparsityPatternIter<'a> {
|
||||
fn from_pattern(pattern: &'a SparsityPattern) -> Self {
|
||||
let first_lane_end = pattern.major_offsets().get(1).unwrap_or(&0);
|
||||
let minors_in_first_lane = &pattern.minor_indices()[0..*first_lane_end];
|
||||
Self {
|
||||
major_offsets: pattern.major_offsets(),
|
||||
minor_indices: pattern.minor_indices(),
|
||||
current_lane_idx: 0,
|
||||
remaining_minors_in_lane: minors_in_first_lane,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for SparsityPatternIter<'a> {
|
||||
type Item = (usize, usize);
|
||||
|
||||
#[inline]
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
// We ensure fast iteration across each lane by iteratively "draining" a slice
|
||||
// corresponding to the remaining column indices in the particular lane.
|
||||
// When we reach the end of this slice, we are at the end of a lane,
|
||||
// and we must do some bookkeeping for preparing the iteration of the next lane
|
||||
// (or stop iteration if we're through all lanes).
|
||||
// This way we can avoid doing unnecessary bookkeeping on every iteration,
|
||||
// instead paying a small price whenever we jump to a new lane.
|
||||
if let Some(minor_idx) = self.remaining_minors_in_lane.first() {
|
||||
let item = Some((self.current_lane_idx, *minor_idx));
|
||||
self.remaining_minors_in_lane = &self.remaining_minors_in_lane[1..];
|
||||
item
|
||||
} else {
|
||||
loop {
|
||||
// Keep skipping lanes until we found a non-empty lane or there are no more lanes
|
||||
if self.current_lane_idx + 2 >= self.major_offsets.len() {
|
||||
// We've processed all lanes, so we're at the end of the iterator
|
||||
// (note: keep in mind that offsets.len() == major_dim() + 1, hence we need +2)
|
||||
return None;
|
||||
} else {
|
||||
// Bump lane index and check if the lane is non-empty
|
||||
self.current_lane_idx += 1;
|
||||
let lower = self.major_offsets[self.current_lane_idx];
|
||||
let upper = self.major_offsets[self.current_lane_idx + 1];
|
||||
if upper > lower {
|
||||
self.remaining_minors_in_lane = &self.minor_indices[(lower + 1)..upper];
|
||||
return Some((self.current_lane_idx, self.minor_indices[lower]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,374 @@
|
|||
//! Functionality for integrating `nalgebra-sparse` with `proptest`.
|
||||
//!
|
||||
//! **This module is only available if the `proptest-support` feature is enabled**.
|
||||
//!
|
||||
//! The strategies provided here are generally expected to be able to generate the entire range
|
||||
//! of possible outputs given the constraints on dimensions and values. However, there are no
|
||||
//! particular guarantees on the distribution of possible values.
|
||||
|
||||
// Contains some patched code from proptest that we can remove in the (hopefully near) future.
|
||||
// See docs in file for more details.
|
||||
mod proptest_patched;
|
||||
|
||||
use crate::coo::CooMatrix;
|
||||
use crate::csc::CscMatrix;
|
||||
use crate::csr::CsrMatrix;
|
||||
use crate::pattern::SparsityPattern;
|
||||
use nalgebra::proptest::DimRange;
|
||||
use nalgebra::{Dim, Scalar};
|
||||
use proptest::collection::{btree_set, hash_map, vec};
|
||||
use proptest::prelude::*;
|
||||
use proptest::sample::Index;
|
||||
use std::cmp::min;
|
||||
use std::iter::repeat;
|
||||
|
||||
fn dense_row_major_coord_strategy(
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
nnz: usize,
|
||||
) -> impl Strategy<Value = Vec<(usize, usize)>> {
|
||||
assert!(nnz <= nrows * ncols);
|
||||
let mut booleans = vec![true; nnz];
|
||||
booleans.append(&mut vec![false; (nrows * ncols) - nnz]);
|
||||
// Make sure that exactly `nnz` of the booleans are true
|
||||
|
||||
// TODO: We cannot use the below code because of a bug in proptest, see
|
||||
// https://github.com/AltSysrq/proptest/pull/217
|
||||
// so for now we're using a patched version of the Shuffle adapter
|
||||
// (see also docs in `proptest_patched`
|
||||
// Just(booleans)
|
||||
// // Need to shuffle to make sure they are randomly distributed
|
||||
// .prop_shuffle()
|
||||
|
||||
proptest_patched::Shuffle(Just(booleans)).prop_map(move |booleans| {
|
||||
booleans
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, is_entry)| {
|
||||
if is_entry {
|
||||
// Convert linear index to row/col pair
|
||||
let i = index / ncols;
|
||||
let j = index % ncols;
|
||||
Some((i, j))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
}
|
||||
|
||||
/// A strategy for generating `nnz` triplets.
|
||||
///
|
||||
/// This strategy should generally only be used when `nnz` is close to `nrows * ncols`.
|
||||
fn dense_triplet_strategy<T>(
|
||||
value_strategy: T,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
nnz: usize,
|
||||
) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
assert!(nnz <= nrows * ncols);
|
||||
|
||||
// Construct a number of booleans of which exactly `nnz` are true.
|
||||
let booleans: Vec<_> = repeat(true)
|
||||
.take(nnz)
|
||||
.chain(repeat(false))
|
||||
.take(nrows * ncols)
|
||||
.collect();
|
||||
|
||||
Just(booleans)
|
||||
// Shuffle the booleans so that they are randomly distributed
|
||||
.prop_shuffle()
|
||||
// Convert the booleans into a list of coordinate pairs
|
||||
.prop_map(move |booleans| {
|
||||
booleans
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, is_entry)| {
|
||||
if is_entry {
|
||||
// Convert linear index to row/col pair
|
||||
let i = index / ncols;
|
||||
let j = index % ncols;
|
||||
Some((i, j))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
// Assign values to each coordinate pair in order to generate a list of triplets
|
||||
.prop_flat_map(move |coords| {
|
||||
vec![value_strategy.clone(); coords.len()].prop_map(move |values| {
|
||||
coords
|
||||
.clone()
|
||||
.into_iter()
|
||||
.zip(values)
|
||||
.map(|((i, j), v)| (i, j, v))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// A strategy for generating `nnz` triplets.
|
||||
///
|
||||
/// This strategy should generally only be used when `nnz << nrows * ncols`. If `nnz` is too
|
||||
/// close to `nrows * ncols` it may fail due to excessive rejected samples.
|
||||
fn sparse_triplet_strategy<T>(
|
||||
value_strategy: T,
|
||||
nrows: usize,
|
||||
ncols: usize,
|
||||
nnz: usize,
|
||||
) -> impl Strategy<Value = Vec<(usize, usize, T::Value)>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
// Have to handle the zero case: proptest doesn't like empty ranges (i.e. 0 .. 0)
|
||||
let row_index_strategy = if nrows > 0 { 0..nrows } else { 0..1 };
|
||||
let col_index_strategy = if ncols > 0 { 0..ncols } else { 0..1 };
|
||||
let coord_strategy = (row_index_strategy, col_index_strategy);
|
||||
hash_map(coord_strategy, value_strategy.clone(), nnz)
|
||||
.prop_map(|hash_map| {
|
||||
let triplets: Vec<_> = hash_map.into_iter().map(|((i, j), v)| (i, j, v)).collect();
|
||||
triplets
|
||||
})
|
||||
// Although order in the hash map is unspecified, it's not necessarily *random*
|
||||
// - or, in particular, it does not necessarily sample the whole space of possible outcomes -
|
||||
// so we additionally shuffle the triplets
|
||||
.prop_shuffle()
|
||||
}
|
||||
|
||||
/// A strategy for producing COO matrices without duplicate entries.
|
||||
///
|
||||
/// The values of the matrix are picked from the provided `value_strategy`, while the size of the
|
||||
/// generated matrices is determined by the ranges `rows` and `cols`. The number of explicitly
|
||||
/// stored entries is bounded from above by `max_nonzeros`. Note that the matrix might still
|
||||
/// contain explicitly stored zeroes if the value strategy is capable of generating zero values.
|
||||
pub fn coo_no_duplicates<T>(
|
||||
value_strategy: T,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize,
|
||||
) -> impl Strategy<Value = CooMatrix<T::Value>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
(
|
||||
rows.into().to_range_inclusive(),
|
||||
cols.into().to_range_inclusive(),
|
||||
)
|
||||
.prop_flat_map(move |(nrows, ncols)| {
|
||||
let max_nonzeros = min(max_nonzeros, nrows * ncols);
|
||||
let size_range = 0..=max_nonzeros;
|
||||
let value_strategy = value_strategy.clone();
|
||||
|
||||
size_range
|
||||
.prop_flat_map(move |nnz| {
|
||||
let value_strategy = value_strategy.clone();
|
||||
if nnz as f64 > 0.10 * (nrows as f64) * (ncols as f64) {
|
||||
// If the number of nnz is sufficiently dense, then use the dense
|
||||
// sample strategy
|
||||
dense_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
|
||||
} else {
|
||||
// Otherwise, use a hash map strategy so that we can get a sparse sampling
|
||||
// (so that complexity is rather on the order of max_nnz than nrows * ncols)
|
||||
sparse_triplet_strategy(value_strategy, nrows, ncols, nnz).boxed()
|
||||
}
|
||||
})
|
||||
.prop_map(move |triplets| {
|
||||
let mut coo = CooMatrix::new(nrows, ncols);
|
||||
for (i, j, v) in triplets {
|
||||
coo.push(i, j, v);
|
||||
}
|
||||
coo
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// A strategy for producing COO matrices with duplicate entries.
|
||||
///
|
||||
/// The values of the matrix are picked from the provided `value_strategy`, while the size of the
|
||||
/// generated matrices is determined by the ranges `rows` and `cols`. Note that the values
|
||||
/// only apply to individual entries, and since this strategy can generate duplicate entries,
|
||||
/// the matrix will generally have values outside the range determined by `value_strategy` when
|
||||
/// converted to other formats, since the duplicate entries are summed together in this case.
|
||||
///
|
||||
/// The number of explicitly stored entries is bounded from above by `max_nonzeros`. The maximum
|
||||
/// number of duplicate entries is determined by `max_duplicates`. Note that the matrix might still
|
||||
/// contain explicitly stored zeroes if the value strategy is capable of generating zero values.
|
||||
pub fn coo_with_duplicates<T>(
|
||||
value_strategy: T,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize,
|
||||
max_duplicates: usize,
|
||||
) -> impl Strategy<Value = CooMatrix<T::Value>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
let coo_strategy = coo_no_duplicates(value_strategy.clone(), rows, cols, max_nonzeros);
|
||||
let duplicate_strategy = vec((any::<Index>(), value_strategy.clone()), 0..=max_duplicates);
|
||||
(coo_strategy, duplicate_strategy)
|
||||
.prop_flat_map(|(coo, duplicates)| {
|
||||
let mut triplets: Vec<(usize, usize, T::Value)> = coo
|
||||
.triplet_iter()
|
||||
.map(|(i, j, v)| (i, j, v.clone()))
|
||||
.collect();
|
||||
if !triplets.is_empty() {
|
||||
let duplicates_iter: Vec<_> = duplicates
|
||||
.into_iter()
|
||||
.map(|(idx, val)| {
|
||||
let (i, j, _) = idx.get(&triplets);
|
||||
(*i, *j, val)
|
||||
})
|
||||
.collect();
|
||||
triplets.extend(duplicates_iter);
|
||||
}
|
||||
// Make sure to shuffle so that the duplicates get mixed in with the non-duplicates
|
||||
let shuffled = Just(triplets).prop_shuffle();
|
||||
(Just(coo.nrows()), Just(coo.ncols()), shuffled)
|
||||
})
|
||||
.prop_map(move |(nrows, ncols, triplets)| {
|
||||
let mut coo = CooMatrix::new(nrows, ncols);
|
||||
for (i, j, v) in triplets {
|
||||
coo.push(i, j, v);
|
||||
}
|
||||
coo
|
||||
})
|
||||
}
|
||||
|
||||
fn sparsity_pattern_from_row_major_coords<I>(
|
||||
nmajor: usize,
|
||||
nminor: usize,
|
||||
coords: I,
|
||||
) -> SparsityPattern
|
||||
where
|
||||
I: Iterator<Item = (usize, usize)> + ExactSizeIterator,
|
||||
{
|
||||
let mut minors = Vec::with_capacity(coords.len());
|
||||
let mut offsets = Vec::with_capacity(nmajor + 1);
|
||||
let mut current_major = 0;
|
||||
offsets.push(0);
|
||||
for (idx, (i, j)) in coords.enumerate() {
|
||||
assert!(i >= current_major);
|
||||
assert!(
|
||||
i < nmajor && j < nminor,
|
||||
"Generated coords are out of bounds"
|
||||
);
|
||||
while current_major < i {
|
||||
offsets.push(idx);
|
||||
current_major += 1;
|
||||
}
|
||||
minors.push(j);
|
||||
}
|
||||
|
||||
while current_major < nmajor {
|
||||
offsets.push(minors.len());
|
||||
current_major += 1;
|
||||
}
|
||||
|
||||
assert_eq!(offsets.first().unwrap(), &0);
|
||||
assert_eq!(offsets.len(), nmajor + 1);
|
||||
|
||||
SparsityPattern::try_from_offsets_and_indices(nmajor, nminor, offsets, minors)
|
||||
.expect("Internal error: Generated sparsity pattern is invalid")
|
||||
}
|
||||
|
||||
/// A strategy for generating sparsity patterns.
|
||||
pub fn sparsity_pattern(
|
||||
major_lanes: impl Into<DimRange>,
|
||||
minor_lanes: impl Into<DimRange>,
|
||||
max_nonzeros: usize,
|
||||
) -> impl Strategy<Value = SparsityPattern> {
|
||||
(
|
||||
major_lanes.into().to_range_inclusive(),
|
||||
minor_lanes.into().to_range_inclusive(),
|
||||
)
|
||||
.prop_flat_map(move |(nmajor, nminor)| {
|
||||
let max_nonzeros = min(nmajor * nminor, max_nonzeros);
|
||||
(Just(nmajor), Just(nminor), 0..=max_nonzeros)
|
||||
})
|
||||
.prop_flat_map(move |(nmajor, nminor, nnz)| {
|
||||
if 10 * nnz < nmajor * nminor {
|
||||
// If nnz is small compared to a dense matrix, then use a sparse sampling strategy
|
||||
btree_set((0..nmajor, 0..nminor), nnz)
|
||||
.prop_map(move |coords| {
|
||||
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords.into_iter())
|
||||
})
|
||||
.boxed()
|
||||
} else {
|
||||
// If the required number of nonzeros is sufficiently dense,
|
||||
// we instead use a dense sampling
|
||||
dense_row_major_coord_strategy(nmajor, nminor, nnz)
|
||||
.prop_map(move |coords| {
|
||||
let coords = coords.into_iter();
|
||||
sparsity_pattern_from_row_major_coords(nmajor, nminor, coords)
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// A strategy for generating CSR matrices.
|
||||
pub fn csr<T>(
|
||||
value_strategy: T,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize,
|
||||
) -> impl Strategy<Value = CsrMatrix<T::Value>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
let rows = rows.into();
|
||||
let cols = cols.into();
|
||||
sparsity_pattern(
|
||||
rows.lower_bound().value()..=rows.upper_bound().value(),
|
||||
cols.lower_bound().value()..=cols.upper_bound().value(),
|
||||
max_nonzeros,
|
||||
)
|
||||
.prop_flat_map(move |pattern| {
|
||||
let nnz = pattern.nnz();
|
||||
let values = vec![value_strategy.clone(); nnz];
|
||||
(Just(pattern), values)
|
||||
})
|
||||
.prop_map(|(pattern, values)| {
|
||||
CsrMatrix::try_from_pattern_and_values(pattern, values)
|
||||
.expect("Internal error: Generated CsrMatrix is invalid")
|
||||
})
|
||||
}
|
||||
|
||||
/// A strategy for generating CSC matrices.
|
||||
pub fn csc<T>(
|
||||
value_strategy: T,
|
||||
rows: impl Into<DimRange>,
|
||||
cols: impl Into<DimRange>,
|
||||
max_nonzeros: usize,
|
||||
) -> impl Strategy<Value = CscMatrix<T::Value>>
|
||||
where
|
||||
T: Strategy + Clone + 'static,
|
||||
T::Value: Scalar,
|
||||
{
|
||||
let rows = rows.into();
|
||||
let cols = cols.into();
|
||||
sparsity_pattern(
|
||||
cols.lower_bound().value()..=cols.upper_bound().value(),
|
||||
rows.lower_bound().value()..=rows.upper_bound().value(),
|
||||
max_nonzeros,
|
||||
)
|
||||
.prop_flat_map(move |pattern| {
|
||||
let nnz = pattern.nnz();
|
||||
let values = vec![value_strategy.clone(); nnz];
|
||||
(Just(pattern), values)
|
||||
})
|
||||
.prop_map(|(pattern, values)| {
|
||||
CscMatrix::try_from_pattern_and_values(pattern, values)
|
||||
.expect("Internal error: Generated CscMatrix is invalid")
|
||||
})
|
||||
}
|
|
@ -0,0 +1,146 @@
|
|||
//! Contains a modified implementation of `proptest::strategy::Shuffle`.
|
||||
//!
|
||||
//! The current implementation in `proptest` does not generate all permutations, which is
|
||||
//! problematic for our proptest generators. The issue has been fixed in
|
||||
//! https://github.com/AltSysrq/proptest/pull/217
|
||||
//! but it has yet to be merged and released. As soon as this fix makes it into a new release,
|
||||
//! the modified code here can be removed.
|
||||
//!
|
||||
/*!
|
||||
This code has been copied and adapted from
|
||||
https://github.com/AltSysrq/proptest/blob/master/proptest/src/strategy/shuffle.rs
|
||||
The original licensing text is:
|
||||
|
||||
//-
|
||||
// Copyright 2017 Jason Lingle
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
|
||||
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
|
||||
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
|
||||
// option. This file may not be copied, modified, or distributed
|
||||
// except according to those terms.
|
||||
|
||||
*/
|
||||
|
||||
use proptest::num;
|
||||
use proptest::prelude::Rng;
|
||||
use proptest::strategy::{NewTree, Shuffleable, Strategy, ValueTree};
|
||||
use proptest::test_runner::{TestRng, TestRunner};
|
||||
use std::cell::Cell;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[must_use = "strategies do nothing unless used"]
|
||||
pub struct Shuffle<S>(pub(super) S);
|
||||
|
||||
impl<S: Strategy> Strategy for Shuffle<S>
|
||||
where
|
||||
S::Value: Shuffleable,
|
||||
{
|
||||
type Tree = ShuffleValueTree<S::Tree>;
|
||||
type Value = S::Value;
|
||||
|
||||
fn new_tree(&self, runner: &mut TestRunner) -> NewTree<Self> {
|
||||
let rng = runner.new_rng();
|
||||
|
||||
self.0.new_tree(runner).map(|inner| ShuffleValueTree {
|
||||
inner,
|
||||
rng,
|
||||
dist: Cell::new(None),
|
||||
simplifying_inner: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ShuffleValueTree<V> {
|
||||
inner: V,
|
||||
rng: TestRng,
|
||||
dist: Cell<Option<num::usize::BinarySearch>>,
|
||||
simplifying_inner: bool,
|
||||
}
|
||||
|
||||
impl<V: ValueTree> ShuffleValueTree<V>
|
||||
where
|
||||
V::Value: Shuffleable,
|
||||
{
|
||||
fn init_dist(&self, dflt: usize) -> usize {
|
||||
if self.dist.get().is_none() {
|
||||
self.dist.set(Some(num::usize::BinarySearch::new(dflt)));
|
||||
}
|
||||
|
||||
self.dist.get().unwrap().current()
|
||||
}
|
||||
|
||||
fn force_init_dist(&self) {
|
||||
if self.dist.get().is_none() {
|
||||
let _ = self.init_dist(self.current().shuffle_len());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<V: ValueTree> ValueTree for ShuffleValueTree<V>
|
||||
where
|
||||
V::Value: Shuffleable,
|
||||
{
|
||||
type Value = V::Value;
|
||||
|
||||
fn current(&self) -> V::Value {
|
||||
let mut value = self.inner.current();
|
||||
let len = value.shuffle_len();
|
||||
// The maximum distance to swap elements. This could be larger than
|
||||
// `value` if `value` has reduced size during shrinking; that's OK,
|
||||
// since we only use this to filter swaps.
|
||||
let max_swap = self.init_dist(len);
|
||||
|
||||
// If empty collection or all swaps will be filtered out, there's
|
||||
// nothing to shuffle.
|
||||
if 0 == len || 0 == max_swap {
|
||||
return value;
|
||||
}
|
||||
|
||||
let mut rng = self.rng.clone();
|
||||
|
||||
for start_index in 0..len - 1 {
|
||||
// Determine the other index to be swapped, then skip the swap if
|
||||
// it is too far. This ordering is critical, as it ensures that we
|
||||
// generate the same sequence of random numbers every time.
|
||||
|
||||
// NOTE: The below line is the whole reason for the existence of this adapted code
|
||||
// We need to be able to swap with the same element, so that some elements remain in
|
||||
// place rather being swapped
|
||||
// let end_index = rng.gen_range(start_index + 1..len);
|
||||
let end_index = rng.gen_range(start_index..len);
|
||||
if end_index - start_index <= max_swap {
|
||||
value.shuffle_swap(start_index, end_index);
|
||||
}
|
||||
}
|
||||
|
||||
value
|
||||
}
|
||||
|
||||
fn simplify(&mut self) -> bool {
|
||||
if self.simplifying_inner {
|
||||
self.inner.simplify()
|
||||
} else {
|
||||
// Ensure that we've initialised `dist` to *something* to give
|
||||
// consistent non-panicking behaviour even if called in an
|
||||
// unexpected sequence.
|
||||
self.force_init_dist();
|
||||
if self.dist.get_mut().as_mut().unwrap().simplify() {
|
||||
true
|
||||
} else {
|
||||
self.simplifying_inner = true;
|
||||
self.inner.simplify()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn complicate(&mut self) -> bool {
|
||||
if self.simplifying_inner {
|
||||
self.inner.complicate()
|
||||
} else {
|
||||
self.force_init_dist();
|
||||
self.dist.get_mut().as_mut().unwrap().complicate()
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
use nalgebra_sparse::csc::CscMatrix;
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::proptest::{csc, csr};
|
||||
use proptest::strategy::Strategy;
|
||||
use std::convert::TryFrom;
|
||||
use std::fmt::Debug;
|
||||
use std::ops::RangeInclusive;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! assert_panics {
|
||||
($e:expr) => {{
|
||||
use std::panic::catch_unwind;
|
||||
use std::stringify;
|
||||
let expr_string = stringify!($e);
|
||||
|
||||
// Note: We cannot manipulate the panic hook here, because it is global and the test
|
||||
// suite is run in parallel, which leads to race conditions in the sense
|
||||
// that some regular tests that panic might not output anything anymore.
|
||||
// Unfortunately this means that output is still printed to stdout if
|
||||
// we run cargo test -- --nocapture. But Cargo does not forward this if the test
|
||||
// binary is not run with nocapture, so it is somewhat acceptable nonetheless.
|
||||
|
||||
let result = catch_unwind(|| $e);
|
||||
if result.is_ok() {
|
||||
panic!(
|
||||
"assert_panics!({}) failed: the expression did not panic.",
|
||||
expr_string
|
||||
);
|
||||
}
|
||||
}};
|
||||
}
|
||||
|
||||
pub const PROPTEST_MATRIX_DIM: RangeInclusive<usize> = 0..=6;
|
||||
pub const PROPTEST_MAX_NNZ: usize = 40;
|
||||
pub const PROPTEST_I32_VALUE_STRATEGY: RangeInclusive<i32> = -5..=5;
|
||||
|
||||
pub fn value_strategy<T>() -> RangeInclusive<T>
|
||||
where
|
||||
T: TryFrom<i32>,
|
||||
T::Error: Debug,
|
||||
{
|
||||
let (start, end) = (
|
||||
PROPTEST_I32_VALUE_STRATEGY.start(),
|
||||
PROPTEST_I32_VALUE_STRATEGY.end(),
|
||||
);
|
||||
T::try_from(*start).unwrap()..=T::try_from(*end).unwrap()
|
||||
}
|
||||
|
||||
pub fn non_zero_i32_value_strategy() -> impl Strategy<Value = i32> {
|
||||
let (start, end) = (
|
||||
PROPTEST_I32_VALUE_STRATEGY.start(),
|
||||
PROPTEST_I32_VALUE_STRATEGY.end(),
|
||||
);
|
||||
assert!(start < &0);
|
||||
assert!(end > &0);
|
||||
// Note: we don't use RangeInclusive for the second range, because then we'd have different
|
||||
// types, which would require boxing
|
||||
(*start..0).prop_union(1..*end + 1)
|
||||
}
|
||||
|
||||
pub fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
|
||||
csr(
|
||||
PROPTEST_I32_VALUE_STRATEGY,
|
||||
PROPTEST_MATRIX_DIM,
|
||||
PROPTEST_MATRIX_DIM,
|
||||
PROPTEST_MAX_NNZ,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> {
|
||||
csc(
|
||||
PROPTEST_I32_VALUE_STRATEGY,
|
||||
PROPTEST_MATRIX_DIM,
|
||||
PROPTEST_MATRIX_DIM,
|
||||
PROPTEST_MAX_NNZ,
|
||||
)
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
//! Unit tests
|
||||
#[cfg(any(not(feature = "proptest-support"), not(feature = "compare")))]
|
||||
compile_error!("Tests must be run with features `proptest-support` and `compare`");
|
||||
|
||||
mod unit_tests;
|
||||
|
||||
#[macro_use]
|
||||
pub mod common;
|
|
@ -0,0 +1,8 @@
|
|||
# Seeds for failure cases proptest has generated in the past. It is
|
||||
# automatically read and these particular cases re-run before any
|
||||
# novel cases are generated.
|
||||
#
|
||||
# It is recommended to check this file in to source control so that
|
||||
# everyone who runs the test benefits from these saved cases.
|
||||
cc 3f71c8edc555965e521e3aaf58c736240a0e333c3a9d54e8a836d7768c371215 # shrinks to matrix = CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0], minor_indices: [], minor_dim: 0 }, values: [] } }
|
||||
cc aef645e3184b814ef39fbb10234f12e6ff502ab515dabefafeedab5895e22b12 # shrinks to (matrix, rhs) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 4, 7, 11, 14], minor_indices: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 0, 2, 3], minor_dim: 4 }, values: [1.0, 0.0, 0.0, 0.0, 0.0, 40.90124126326177, 36.975170911665906, 0.0, 36.975170911665906, 42.51062858727923, -12.984115201530539, 0.0, -12.984115201530539, 27.73953543265418] } }, Matrix { data: VecStorage { data: [0.0, 0.0, 0.0, -4.05763092330143], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 1 } } })
|
|
@ -0,0 +1,117 @@
|
|||
#![cfg_attr(rustfmt, rustfmt_skip)]
|
||||
use crate::common::{value_strategy, PROPTEST_MATRIX_DIM, PROPTEST_MAX_NNZ};
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
use nalgebra_sparse::factorization::{CscCholesky};
|
||||
use nalgebra_sparse::proptest::csc;
|
||||
use nalgebra::{Matrix5, Vector5, Cholesky, DMatrix};
|
||||
use nalgebra::proptest::matrix;
|
||||
|
||||
use proptest::prelude::*;
|
||||
use matrixcompare::{assert_matrix_eq, prop_assert_matrix_eq};
|
||||
|
||||
fn positive_definite() -> impl Strategy<Value=CscMatrix<f64>> {
|
||||
let csc_f64 = csc(value_strategy::<f64>(),
|
||||
PROPTEST_MATRIX_DIM,
|
||||
PROPTEST_MATRIX_DIM,
|
||||
PROPTEST_MAX_NNZ);
|
||||
csc_f64
|
||||
.prop_map(|x| {
|
||||
// Add a small multiple of the identity to ensure positive definiteness
|
||||
x.transpose() * &x + CscMatrix::identity(x.ncols())
|
||||
})
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn cholesky_correct_for_positive_definite_matrices(
|
||||
matrix in positive_definite()
|
||||
) {
|
||||
let cholesky = CscCholesky::factor(&matrix).unwrap();
|
||||
let l = cholesky.take_l();
|
||||
let matrix_reconstructed = &l * l.transpose();
|
||||
|
||||
prop_assert_matrix_eq!(matrix_reconstructed, matrix, comp = abs, tol = 1e-8);
|
||||
|
||||
let is_lower_triangular = l.triplet_iter().all(|(i, j, _)| j <= i);
|
||||
prop_assert!(is_lower_triangular);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cholesky_solve_positive_definite(
|
||||
(matrix, rhs) in positive_definite()
|
||||
.prop_flat_map(|csc| {
|
||||
let rhs = matrix(value_strategy::<f64>(), csc.nrows(), PROPTEST_MATRIX_DIM);
|
||||
(Just(csc), rhs)
|
||||
})
|
||||
) {
|
||||
let cholesky = CscCholesky::factor(&matrix).unwrap();
|
||||
|
||||
// solve_mut
|
||||
{
|
||||
let mut x = rhs.clone();
|
||||
cholesky.solve_mut(&mut x);
|
||||
prop_assert_matrix_eq!(&matrix * &x, rhs, comp=abs, tol=1e-12);
|
||||
}
|
||||
|
||||
// solve
|
||||
{
|
||||
let x = cholesky.solve(&rhs);
|
||||
prop_assert_matrix_eq!(&matrix * &x, rhs, comp=abs, tol=1e-12);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// This is a test ported from nalgebra's "sparse" module, for the original CsCholesky impl
|
||||
#[test]
|
||||
fn cs_cholesky() {
|
||||
let mut a = Matrix5::new(
|
||||
40.0, 0.0, 0.0, 0.0, 0.0,
|
||||
2.0, 60.0, 0.0, 0.0, 0.0,
|
||||
1.0, 0.0, 11.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 50.0, 0.0,
|
||||
1.0, 0.0, 0.0, 4.0, 10.0
|
||||
);
|
||||
a.fill_upper_triangle_with_lower_triangle();
|
||||
test_cholesky(a);
|
||||
|
||||
let a = Matrix5::from_diagonal(&Vector5::new(40.0, 60.0, 11.0, 50.0, 10.0));
|
||||
test_cholesky(a);
|
||||
|
||||
let mut a = Matrix5::new(
|
||||
40.0, 0.0, 0.0, 0.0, 0.0,
|
||||
2.0, 60.0, 0.0, 0.0, 0.0,
|
||||
1.0, 0.0, 11.0, 0.0, 0.0,
|
||||
1.0, 0.0, 0.0, 50.0, 0.0,
|
||||
0.0, 0.0, 0.0, 4.0, 10.0
|
||||
);
|
||||
a.fill_upper_triangle_with_lower_triangle();
|
||||
test_cholesky(a);
|
||||
|
||||
let mut a = Matrix5::new(
|
||||
2.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 2.0, 0.0, 0.0, 0.0,
|
||||
1.0, 1.0, 2.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 2.0, 0.0,
|
||||
1.0, 1.0, 0.0, 0.0, 2.0
|
||||
);
|
||||
a.fill_upper_triangle_with_lower_triangle();
|
||||
// Test crate::new, left_looking, and up_looking implementations.
|
||||
test_cholesky(a);
|
||||
}
|
||||
|
||||
fn test_cholesky(a: Matrix5<f64>) {
|
||||
// TODO: Test "refactor"
|
||||
|
||||
let cs_a = CscMatrix::from(&a);
|
||||
|
||||
let chol_a = Cholesky::new(a).unwrap();
|
||||
let chol_cs_a = CscCholesky::factor(&cs_a).unwrap();
|
||||
|
||||
let l = chol_a.l();
|
||||
let cs_l = chol_cs_a.take_l();
|
||||
|
||||
let l = DMatrix::from_iterator(l.nrows(), l.ncols(), l.iter().cloned());
|
||||
let cs_l_mat = DMatrix::from(&cs_l);
|
||||
assert_matrix_eq!(l, cs_l_mat, comp = abs, tol = 1e-12);
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
# Seeds for failure cases proptest has generated in the past. It is
|
||||
# automatically read and these particular cases re-run before any
|
||||
# novel cases are generated.
|
||||
#
|
||||
# It is recommended to check this file in to source control so that
|
||||
# everyone who runs the test benefits from these saved cases.
|
||||
cc 07cb95127d2700ff2000157938e351ce2b43f3e6419d69b00726abfc03e682bd # shrinks to coo = CooMatrix { nrows: 4, ncols: 5, row_indices: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0], col_indices: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 4, 3], values: [1, -5, -4, -5, 1, 2, 4, -4, -4, -5, 2, -2, 4, -4] }
|
||||
cc 8fdaf70d6091d89a6617573547745e9802bb9c1ce7c6ec7ad4f301cd05d54c5d # shrinks to dense = Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 5 } } }
|
||||
cc 6961760ac7915b57a28230524cea7e9bfcea4f31790e3c0569ea74af904c2d79 # shrinks to coo = CooMatrix { nrows: 6, ncols: 6, row_indices: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0], col_indices: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0], values: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0] }
|
||||
cc c9a1af218f7a974f1fda7b8909c2635d735eedbfe953082ef6b0b92702bf6d1b # shrinks to dense = Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 5 } } }
|
|
@ -0,0 +1,452 @@
|
|||
use crate::common::csc_strategy;
|
||||
use nalgebra::proptest::matrix;
|
||||
use nalgebra::DMatrix;
|
||||
use nalgebra_sparse::convert::serial::{
|
||||
convert_coo_csc, convert_coo_csr, convert_coo_dense, convert_csc_coo, convert_csc_csr,
|
||||
convert_csc_dense, convert_csr_coo, convert_csr_csc, convert_csr_dense, convert_dense_coo,
|
||||
convert_dense_csc, convert_dense_csr,
|
||||
};
|
||||
use nalgebra_sparse::coo::CooMatrix;
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::proptest::{coo_no_duplicates, coo_with_duplicates, csc, csr};
|
||||
use proptest::prelude::*;
|
||||
|
||||
#[test]
|
||||
fn test_convert_dense_coo() {
|
||||
// No duplicates
|
||||
{
|
||||
#[rustfmt::skip]
|
||||
let entries = &[1, 0, 3,
|
||||
0, 5, 0];
|
||||
// The COO representation of a dense matrix is not unique.
|
||||
// Here we implicitly test that the coo matrix is indeed constructed from column-major
|
||||
// iteration of the dense matrix.
|
||||
let dense = DMatrix::from_row_slice(2, 3, entries);
|
||||
let coo = CooMatrix::try_from_triplets(2, 3, vec![0, 1, 0], vec![0, 1, 2], vec![1, 5, 3])
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(CooMatrix::from(&dense), coo);
|
||||
assert_eq!(DMatrix::from(&coo), dense);
|
||||
}
|
||||
|
||||
// Duplicates
|
||||
// No duplicates
|
||||
{
|
||||
#[rustfmt::skip]
|
||||
let entries = &[1, 0, 3,
|
||||
0, 5, 0];
|
||||
// The COO representation of a dense matrix is not unique.
|
||||
// Here we implicitly test that the coo matrix is indeed constructed from column-major
|
||||
// iteration of the dense matrix.
|
||||
let dense = DMatrix::from_row_slice(2, 3, entries);
|
||||
let coo_no_dup =
|
||||
CooMatrix::try_from_triplets(2, 3, vec![0, 1, 0], vec![0, 1, 2], vec![1, 5, 3])
|
||||
.unwrap();
|
||||
let coo_dup = CooMatrix::try_from_triplets(
|
||||
2,
|
||||
3,
|
||||
vec![0, 1, 0, 1],
|
||||
vec![0, 1, 2, 1],
|
||||
vec![1, -2, 3, 7],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(CooMatrix::from(&dense), coo_no_dup);
|
||||
assert_eq!(DMatrix::from(&coo_dup), dense);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_coo_csr() {
|
||||
// No duplicates
|
||||
{
|
||||
let coo = {
|
||||
let mut coo = CooMatrix::new(3, 4);
|
||||
coo.push(1, 3, 4);
|
||||
coo.push(0, 1, 2);
|
||||
coo.push(2, 0, 1);
|
||||
coo.push(2, 3, 2);
|
||||
coo.push(2, 2, 1);
|
||||
coo
|
||||
};
|
||||
|
||||
let expected_csr = CsrMatrix::try_from_csr_data(
|
||||
3,
|
||||
4,
|
||||
vec![0, 1, 2, 5],
|
||||
vec![1, 3, 0, 2, 3],
|
||||
vec![2, 4, 1, 1, 2],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_coo_csr(&coo), expected_csr);
|
||||
}
|
||||
|
||||
// Duplicates
|
||||
{
|
||||
let coo = {
|
||||
let mut coo = CooMatrix::new(3, 4);
|
||||
coo.push(1, 3, 4);
|
||||
coo.push(2, 3, 2);
|
||||
coo.push(0, 1, 2);
|
||||
coo.push(2, 0, 1);
|
||||
coo.push(2, 3, 2);
|
||||
coo.push(0, 1, 3);
|
||||
coo.push(2, 2, 1);
|
||||
coo
|
||||
};
|
||||
|
||||
let expected_csr = CsrMatrix::try_from_csr_data(
|
||||
3,
|
||||
4,
|
||||
vec![0, 1, 2, 5],
|
||||
vec![1, 3, 0, 2, 3],
|
||||
vec![5, 4, 1, 1, 4],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_coo_csr(&coo), expected_csr);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_csr_coo() {
|
||||
let csr = CsrMatrix::try_from_csr_data(
|
||||
3,
|
||||
4,
|
||||
vec![0, 1, 2, 5],
|
||||
vec![1, 3, 0, 2, 3],
|
||||
vec![5, 4, 1, 1, 4],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let expected_coo = CooMatrix::try_from_triplets(
|
||||
3,
|
||||
4,
|
||||
vec![0, 1, 2, 2, 2],
|
||||
vec![1, 3, 0, 2, 3],
|
||||
vec![5, 4, 1, 1, 4],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_csr_coo(&csr), expected_coo);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_coo_csc() {
|
||||
// No duplicates
|
||||
{
|
||||
let coo = {
|
||||
let mut coo = CooMatrix::new(3, 4);
|
||||
coo.push(1, 3, 4);
|
||||
coo.push(0, 1, 2);
|
||||
coo.push(2, 0, 1);
|
||||
coo.push(2, 3, 2);
|
||||
coo.push(2, 2, 1);
|
||||
coo
|
||||
};
|
||||
|
||||
let expected_csc = CscMatrix::try_from_csc_data(
|
||||
3,
|
||||
4,
|
||||
vec![0, 1, 2, 3, 5],
|
||||
vec![2, 0, 2, 1, 2],
|
||||
vec![1, 2, 1, 4, 2],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_coo_csc(&coo), expected_csc);
|
||||
}
|
||||
|
||||
// Duplicates
|
||||
{
|
||||
let coo = {
|
||||
let mut coo = CooMatrix::new(3, 4);
|
||||
coo.push(1, 3, 4);
|
||||
coo.push(2, 3, 2);
|
||||
coo.push(0, 1, 2);
|
||||
coo.push(2, 0, 1);
|
||||
coo.push(2, 3, 2);
|
||||
coo.push(0, 1, 3);
|
||||
coo.push(2, 2, 1);
|
||||
coo
|
||||
};
|
||||
|
||||
let expected_csc = CscMatrix::try_from_csc_data(
|
||||
3,
|
||||
4,
|
||||
vec![0, 1, 2, 3, 5],
|
||||
vec![2, 0, 2, 1, 2],
|
||||
vec![1, 5, 1, 4, 4],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_coo_csc(&coo), expected_csc);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_csc_coo() {
|
||||
let csc = CscMatrix::try_from_csc_data(
|
||||
3,
|
||||
4,
|
||||
vec![0, 1, 2, 3, 5],
|
||||
vec![2, 0, 2, 1, 2],
|
||||
vec![1, 2, 1, 4, 2],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let expected_coo = CooMatrix::try_from_triplets(
|
||||
3,
|
||||
4,
|
||||
vec![2, 0, 2, 1, 2],
|
||||
vec![0, 1, 2, 3, 3],
|
||||
vec![1, 2, 1, 4, 2],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_csc_coo(&csc), expected_coo);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_csr_csc_bidirectional() {
|
||||
let csr = CsrMatrix::try_from_csr_data(
|
||||
3,
|
||||
4,
|
||||
vec![0, 3, 4, 6],
|
||||
vec![1, 2, 3, 0, 1, 3],
|
||||
vec![5, 3, 2, 2, 1, 4],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let csc = CscMatrix::try_from_csc_data(
|
||||
3,
|
||||
4,
|
||||
vec![0, 1, 3, 4, 6],
|
||||
vec![1, 0, 2, 0, 0, 2],
|
||||
vec![2, 5, 1, 3, 2, 4],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(convert_csr_csc(&csr), csc);
|
||||
assert_eq!(convert_csc_csr(&csc), csr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_csr_dense_bidirectional() {
|
||||
let csr = CsrMatrix::try_from_csr_data(
|
||||
3,
|
||||
4,
|
||||
vec![0, 3, 4, 6],
|
||||
vec![1, 2, 3, 0, 1, 3],
|
||||
vec![5, 3, 2, 2, 1, 4],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
#[rustfmt::skip]
|
||||
let dense = DMatrix::from_row_slice(3, 4, &[
|
||||
0, 5, 3, 2,
|
||||
2, 0, 0, 0,
|
||||
0, 1, 0, 4
|
||||
]);
|
||||
|
||||
assert_eq!(convert_csr_dense(&csr), dense);
|
||||
assert_eq!(convert_dense_csr(&dense), csr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_csc_dense_bidirectional() {
|
||||
let csc = CscMatrix::try_from_csc_data(
|
||||
3,
|
||||
4,
|
||||
vec![0, 1, 3, 4, 6],
|
||||
vec![1, 0, 2, 0, 0, 2],
|
||||
vec![2, 5, 1, 3, 2, 4],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
#[rustfmt::skip]
|
||||
let dense = DMatrix::from_row_slice(3, 4, &[
|
||||
0, 5, 3, 2,
|
||||
2, 0, 0, 0,
|
||||
0, 1, 0, 4
|
||||
]);
|
||||
|
||||
assert_eq!(convert_csc_dense(&csc), dense);
|
||||
assert_eq!(convert_dense_csc(&dense), csc);
|
||||
}
|
||||
|
||||
fn coo_strategy() -> impl Strategy<Value = CooMatrix<i32>> {
|
||||
coo_with_duplicates(-5..=5, 0..=6usize, 0..=6usize, 40, 2)
|
||||
}
|
||||
|
||||
fn coo_no_duplicates_strategy() -> impl Strategy<Value = CooMatrix<i32>> {
|
||||
coo_no_duplicates(-5..=5, 0..=6usize, 0..=6usize, 40)
|
||||
}
|
||||
|
||||
fn csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
|
||||
csr(-5..=5, 0..=6usize, 0..=6usize, 40)
|
||||
}
|
||||
|
||||
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
|
||||
fn non_zero_csr_strategy() -> impl Strategy<Value = CsrMatrix<i32>> {
|
||||
csr(1..=5, 0..=6usize, 0..=6usize, 40)
|
||||
}
|
||||
|
||||
/// Avoid generating explicit zero values so that it is possible to reason about sparsity patterns
|
||||
fn non_zero_csc_strategy() -> impl Strategy<Value = CscMatrix<i32>> {
|
||||
csc(1..=5, 0..=6usize, 0..=6usize, 40)
|
||||
}
|
||||
|
||||
fn dense_strategy() -> impl Strategy<Value = DMatrix<i32>> {
|
||||
matrix(-5..=5, 0..=6, 0..=6)
|
||||
}
|
||||
|
||||
proptest! {
|
||||
|
||||
#[test]
|
||||
fn convert_dense_coo_roundtrip(dense in matrix(-5 ..= 5, 0 ..=6, 0..=6)) {
|
||||
let coo = convert_dense_coo(&dense);
|
||||
let dense2 = convert_coo_dense(&coo);
|
||||
prop_assert_eq!(&dense, &dense2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_coo_dense_coo_roundtrip(coo in coo_strategy()) {
|
||||
// We cannot compare the result of the roundtrip coo -> dense -> coo directly for
|
||||
// two reasons:
|
||||
// 1. the COO matrices will generally have different ordering of elements
|
||||
// 2. explicitly stored zero entries in the original matrix will be discarded
|
||||
// when converting back to COO
|
||||
// Therefore we instead compare the results of converting the COO matrix
|
||||
// at the end of the roundtrip with its dense representation
|
||||
let dense = convert_coo_dense(&coo);
|
||||
let coo2 = convert_dense_coo(&dense);
|
||||
let dense2 = convert_coo_dense(&coo2);
|
||||
prop_assert_eq!(dense, dense2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_from_dense_roundtrip(dense in dense_strategy()) {
|
||||
prop_assert_eq!(&dense, &DMatrix::from(&CooMatrix::from(&dense)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_coo_csr_agrees_with_csr_dense(coo in coo_strategy()) {
|
||||
let coo_dense = convert_coo_dense(&coo);
|
||||
let csr = convert_coo_csr(&coo);
|
||||
let csr_dense = convert_csr_dense(&csr);
|
||||
prop_assert_eq!(csr_dense, coo_dense);
|
||||
|
||||
// It might be that COO matrices have a higher nnz due to duplicates,
|
||||
// so we can only check that the CSR matrix has no more than the original COO matrix
|
||||
prop_assert!(csr.nnz() <= coo.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_coo_csr_nnz(coo in coo_no_duplicates_strategy()) {
|
||||
// Check that the NNZ are equal when converting from a CooMatrix without
|
||||
// duplicates to a CSR matrix
|
||||
let csr = convert_coo_csr(&coo);
|
||||
prop_assert_eq!(csr.nnz(), coo.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_csr_coo_roundtrip(csr in csr_strategy()) {
|
||||
let coo = convert_csr_coo(&csr);
|
||||
let csr2 = convert_coo_csr(&coo);
|
||||
prop_assert_eq!(csr2, csr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_from_csr_roundtrip(csr in csr_strategy()) {
|
||||
prop_assert_eq!(&csr, &CsrMatrix::from(&CooMatrix::from(&csr)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_from_dense_roundtrip(dense in dense_strategy()) {
|
||||
prop_assert_eq!(&dense, &DMatrix::from(&CsrMatrix::from(&dense)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_csr_dense_roundtrip(csr in non_zero_csr_strategy()) {
|
||||
// Since we only generate CSR matrices with non-zero values, we know that the
|
||||
// number of explicitly stored entries when converting CSR->Dense->CSR should be
|
||||
// unchanged, so that we can verify that the result is the same as the input
|
||||
let dense = convert_csr_dense(&csr);
|
||||
let csr2 = convert_dense_csr(&dense);
|
||||
prop_assert_eq!(csr2, csr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_csc_coo_roundtrip(csc in csc_strategy()) {
|
||||
let coo = convert_csc_coo(&csc);
|
||||
let csc2 = convert_coo_csc(&coo);
|
||||
prop_assert_eq!(csc2, csc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_from_csc_roundtrip(csc in csc_strategy()) {
|
||||
prop_assert_eq!(&csc, &CscMatrix::from(&CooMatrix::from(&csc)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_csc_dense_roundtrip(csc in non_zero_csc_strategy()) {
|
||||
// Since we only generate CSC matrices with non-zero values, we know that the
|
||||
// number of explicitly stored entries when converting CSC->Dense->CSC should be
|
||||
// unchanged, so that we can verify that the result is the same as the input
|
||||
let dense = convert_csc_dense(&csc);
|
||||
let csc2 = convert_dense_csc(&dense);
|
||||
prop_assert_eq!(csc2, csc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_from_dense_roundtrip(dense in dense_strategy()) {
|
||||
prop_assert_eq!(&dense, &DMatrix::from(&CscMatrix::from(&dense)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_coo_csc_agrees_with_csc_dense(coo in coo_strategy()) {
|
||||
let coo_dense = convert_coo_dense(&coo);
|
||||
let csc = convert_coo_csc(&coo);
|
||||
let csc_dense = convert_csc_dense(&csc);
|
||||
prop_assert_eq!(csc_dense, coo_dense);
|
||||
|
||||
// It might be that COO matrices have a higher nnz due to duplicates,
|
||||
// so we can only check that the CSR matrix has no more than the original COO matrix
|
||||
prop_assert!(csc.nnz() <= coo.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_coo_csc_nnz(coo in coo_no_duplicates_strategy()) {
|
||||
// Check that the NNZ are equal when converting from a CooMatrix without
|
||||
// duplicates to a CSR matrix
|
||||
let csc = convert_coo_csc(&coo);
|
||||
prop_assert_eq!(csc.nnz(), coo.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_csc_csr_roundtrip(csc in csc_strategy()) {
|
||||
let csr = convert_csc_csr(&csc);
|
||||
let csc2 = convert_csr_csc(&csr);
|
||||
prop_assert_eq!(csc2, csc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_csr_csc_roundtrip(csr in csr_strategy()) {
|
||||
let csc = convert_csr_csc(&csr);
|
||||
let csr2 = convert_csc_csr(&csc);
|
||||
prop_assert_eq!(csr2, csr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_from_csr_roundtrip(csr in csr_strategy()) {
|
||||
prop_assert_eq!(&csr, &CsrMatrix::from(&CscMatrix::from(&csr)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_from_csc_roundtrip(csc in csc_strategy()) {
|
||||
prop_assert_eq!(&csc, &CscMatrix::from(&CsrMatrix::from(&csc)));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,254 @@
|
|||
use crate::assert_panics;
|
||||
use nalgebra::DMatrix;
|
||||
use nalgebra_sparse::coo::CooMatrix;
|
||||
use nalgebra_sparse::SparseFormatErrorKind;
|
||||
|
||||
#[test]
|
||||
fn coo_construction_for_valid_data() {
|
||||
// Test that construction with try_from_triplets succeeds, that the state of the
|
||||
// matrix afterwards is as expected, and that the dense representation matches expectations.
|
||||
|
||||
{
|
||||
// Zero matrix
|
||||
let coo =
|
||||
CooMatrix::<i32>::try_from_triplets(3, 2, Vec::new(), Vec::new(), Vec::new()).unwrap();
|
||||
assert_eq!(coo.nrows(), 3);
|
||||
assert_eq!(coo.ncols(), 2);
|
||||
assert!(coo.triplet_iter().next().is_none());
|
||||
assert!(coo.row_indices().is_empty());
|
||||
assert!(coo.col_indices().is_empty());
|
||||
assert!(coo.values().is_empty());
|
||||
|
||||
assert_eq!(DMatrix::from(&coo), DMatrix::repeat(3, 2, 0));
|
||||
}
|
||||
|
||||
{
|
||||
// Arbitrary matrix, no duplicates
|
||||
let i = vec![0, 1, 0, 0, 2];
|
||||
let j = vec![0, 2, 1, 3, 3];
|
||||
let v = vec![2, 3, 7, 3, 1];
|
||||
let coo =
|
||||
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
|
||||
assert_eq!(coo.nrows(), 3);
|
||||
assert_eq!(coo.ncols(), 5);
|
||||
|
||||
assert_eq!(i.as_slice(), coo.row_indices());
|
||||
assert_eq!(j.as_slice(), coo.col_indices());
|
||||
assert_eq!(v.as_slice(), coo.values());
|
||||
|
||||
let expected_triplets: Vec<_> = i
|
||||
.iter()
|
||||
.zip(&j)
|
||||
.zip(&v)
|
||||
.map(|((i, j), v)| (*i, *j, *v))
|
||||
.collect();
|
||||
let actual_triplets: Vec<_> = coo.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||
assert_eq!(actual_triplets, expected_triplets);
|
||||
|
||||
#[rustfmt::skip]
|
||||
let expected_dense = DMatrix::from_row_slice(3, 5, &[
|
||||
2, 7, 0, 3, 0,
|
||||
0, 0, 3, 0, 0,
|
||||
0, 0, 0, 1, 0
|
||||
]);
|
||||
assert_eq!(DMatrix::from(&coo), expected_dense);
|
||||
}
|
||||
|
||||
{
|
||||
// Arbitrary matrix, with duplicates
|
||||
let i = vec![0, 1, 0, 0, 0, 0, 2, 1];
|
||||
let j = vec![0, 2, 0, 1, 0, 3, 3, 2];
|
||||
let v = vec![2, 3, 4, 7, 1, 3, 1, 5];
|
||||
let coo =
|
||||
CooMatrix::<i32>::try_from_triplets(3, 5, i.clone(), j.clone(), v.clone()).unwrap();
|
||||
assert_eq!(coo.nrows(), 3);
|
||||
assert_eq!(coo.ncols(), 5);
|
||||
|
||||
assert_eq!(i.as_slice(), coo.row_indices());
|
||||
assert_eq!(j.as_slice(), coo.col_indices());
|
||||
assert_eq!(v.as_slice(), coo.values());
|
||||
|
||||
let expected_triplets: Vec<_> = i
|
||||
.iter()
|
||||
.zip(&j)
|
||||
.zip(&v)
|
||||
.map(|((i, j), v)| (*i, *j, *v))
|
||||
.collect();
|
||||
let actual_triplets: Vec<_> = coo.triplet_iter().map(|(i, j, v)| (i, j, *v)).collect();
|
||||
assert_eq!(actual_triplets, expected_triplets);
|
||||
|
||||
#[rustfmt::skip]
|
||||
let expected_dense = DMatrix::from_row_slice(3, 5, &[
|
||||
7, 7, 0, 3, 0,
|
||||
0, 0, 8, 0, 0,
|
||||
0, 0, 0, 1, 0
|
||||
]);
|
||||
assert_eq!(DMatrix::from(&coo), expected_dense);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_try_from_triplets_reports_out_of_bounds_indices() {
|
||||
{
|
||||
// 0x0 matrix
|
||||
let result = CooMatrix::<i32>::try_from_triplets(0, 0, vec![0], vec![0], vec![2]);
|
||||
assert!(matches!(
|
||||
result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::IndexOutOfBounds
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
// 1x1 matrix, row out of bounds
|
||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![0], vec![2]);
|
||||
assert!(matches!(
|
||||
result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::IndexOutOfBounds
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
// 1x1 matrix, col out of bounds
|
||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![0], vec![1], vec![2]);
|
||||
assert!(matches!(
|
||||
result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::IndexOutOfBounds
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
// 1x1 matrix, row and col out of bounds
|
||||
let result = CooMatrix::<i32>::try_from_triplets(1, 1, vec![1], vec![1], vec![2]);
|
||||
assert!(matches!(
|
||||
result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::IndexOutOfBounds
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
// Arbitrary matrix, row out of bounds
|
||||
let i = vec![0, 1, 0, 3, 2];
|
||||
let j = vec![0, 2, 1, 3, 3];
|
||||
let v = vec![2, 3, 7, 3, 1];
|
||||
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
||||
assert!(matches!(
|
||||
result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::IndexOutOfBounds
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
// Arbitrary matrix, col out of bounds
|
||||
let i = vec![0, 1, 0, 0, 2];
|
||||
let j = vec![0, 2, 1, 5, 3];
|
||||
let v = vec![2, 3, 7, 3, 1];
|
||||
let result = CooMatrix::<i32>::try_from_triplets(3, 5, i, j, v);
|
||||
assert!(matches!(
|
||||
result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::IndexOutOfBounds
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_try_from_triplets_panics_on_mismatched_vectors() {
|
||||
// Check that try_from_triplets panics when the triplet vectors have different lengths
|
||||
macro_rules! assert_errs {
|
||||
($result:expr) => {
|
||||
assert!(matches!(
|
||||
$result.unwrap_err().kind(),
|
||||
SparseFormatErrorKind::InvalidStructure
|
||||
))
|
||||
};
|
||||
}
|
||||
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![1, 2],
|
||||
vec![0],
|
||||
vec![0]
|
||||
));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![1],
|
||||
vec![0, 0],
|
||||
vec![0]
|
||||
));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![1],
|
||||
vec![0],
|
||||
vec![0, 1]
|
||||
));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![1, 2],
|
||||
vec![0, 1],
|
||||
vec![0]
|
||||
));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![1],
|
||||
vec![0, 1],
|
||||
vec![0, 1]
|
||||
));
|
||||
assert_errs!(CooMatrix::<i32>::try_from_triplets(
|
||||
3,
|
||||
5,
|
||||
vec![1, 1],
|
||||
vec![0],
|
||||
vec![0, 1]
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_push_valid_entries() {
|
||||
let mut coo = CooMatrix::new(3, 3);
|
||||
|
||||
coo.push(0, 0, 1);
|
||||
assert_eq!(coo.triplet_iter().collect::<Vec<_>>(), vec![(0, 0, &1)]);
|
||||
|
||||
coo.push(0, 0, 2);
|
||||
assert_eq!(
|
||||
coo.triplet_iter().collect::<Vec<_>>(),
|
||||
vec![(0, 0, &1), (0, 0, &2)]
|
||||
);
|
||||
|
||||
coo.push(2, 2, 3);
|
||||
assert_eq!(
|
||||
coo.triplet_iter().collect::<Vec<_>>(),
|
||||
vec![(0, 0, &1), (0, 0, &2), (2, 2, &3)]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coo_push_out_of_bounds_entries() {
|
||||
{
|
||||
// 0x0 matrix
|
||||
let coo = CooMatrix::new(0, 0);
|
||||
assert_panics!(coo.clone().push(0, 0, 1));
|
||||
}
|
||||
|
||||
{
|
||||
// 0x1 matrix
|
||||
assert_panics!(CooMatrix::new(0, 1).push(0, 0, 1));
|
||||
}
|
||||
|
||||
{
|
||||
// 1x0 matrix
|
||||
assert_panics!(CooMatrix::new(1, 0).push(0, 0, 1));
|
||||
}
|
||||
|
||||
{
|
||||
// Arbitrary matrix dimensions
|
||||
let coo = CooMatrix::new(3, 2);
|
||||
assert_panics!(coo.clone().push(3, 0, 1));
|
||||
assert_panics!(coo.clone().push(2, 2, 1));
|
||||
assert_panics!(coo.clone().push(3, 2, 1));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
# Seeds for failure cases proptest has generated in the past. It is
|
||||
# automatically read and these particular cases re-run before any
|
||||
# novel cases are generated.
|
||||
#
|
||||
# It is recommended to check this file in to source control so that
|
||||
# everyone who runs the test benefits from these saved cases.
|
||||
cc a71b4654827840ed539b82cd7083615b0fb3f75933de6a7d91d8148a2bf34960 # shrinks to (csc, triplet_subset) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 1, 1, 1, 1, 1, 1], minor_indices: [0], minor_dim: 4 }, values: [0] } }, {})
|
|
@ -0,0 +1,605 @@
|
|||
use nalgebra::DMatrix;
|
||||
use nalgebra_sparse::csc::CscMatrix;
|
||||
use nalgebra_sparse::{SparseEntry, SparseEntryMut, SparseFormatErrorKind};
|
||||
|
||||
use proptest::prelude::*;
|
||||
use proptest::sample::subsequence;
|
||||
|
||||
use crate::assert_panics;
|
||||
use crate::common::csc_strategy;
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[test]
|
||||
fn csc_matrix_valid_data() {
|
||||
// Construct matrix from valid data and check that selected methods return results
|
||||
// that agree with expectations.
|
||||
|
||||
{
|
||||
// A CSC matrix with zero explicitly stored entries
|
||||
let offsets = vec![0, 0, 0, 0];
|
||||
let indices = vec![];
|
||||
let values = Vec::<i32>::new();
|
||||
let mut matrix = CscMatrix::try_from_csc_data(2, 3, offsets, indices, values).unwrap();
|
||||
|
||||
assert_eq!(matrix, CscMatrix::zeros(2, 3));
|
||||
|
||||
assert_eq!(matrix.nrows(), 2);
|
||||
assert_eq!(matrix.ncols(), 3);
|
||||
assert_eq!(matrix.nnz(), 0);
|
||||
assert_eq!(matrix.col_offsets(), &[0, 0, 0, 0]);
|
||||
assert_eq!(matrix.row_indices(), &[]);
|
||||
assert_eq!(matrix.values(), &[]);
|
||||
|
||||
assert!(matrix.triplet_iter().next().is_none());
|
||||
assert!(matrix.triplet_iter_mut().next().is_none());
|
||||
|
||||
assert_eq!(matrix.col(0).nrows(), 2);
|
||||
assert_eq!(matrix.col(0).nnz(), 0);
|
||||
assert_eq!(matrix.col(0).row_indices(), &[]);
|
||||
assert_eq!(matrix.col(0).values(), &[]);
|
||||
assert_eq!(matrix.col_mut(0).nrows(), 2);
|
||||
assert_eq!(matrix.col_mut(0).nnz(), 0);
|
||||
assert_eq!(matrix.col_mut(0).row_indices(), &[]);
|
||||
assert_eq!(matrix.col_mut(0).values(), &[]);
|
||||
assert_eq!(matrix.col_mut(0).values_mut(), &[]);
|
||||
assert_eq!(
|
||||
matrix.col_mut(0).rows_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.col(1).nrows(), 2);
|
||||
assert_eq!(matrix.col(1).nnz(), 0);
|
||||
assert_eq!(matrix.col(1).row_indices(), &[]);
|
||||
assert_eq!(matrix.col(1).values(), &[]);
|
||||
assert_eq!(matrix.col_mut(1).nrows(), 2);
|
||||
assert_eq!(matrix.col_mut(1).nnz(), 0);
|
||||
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
|
||||
assert_eq!(matrix.col_mut(1).values(), &[]);
|
||||
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
|
||||
assert_eq!(
|
||||
matrix.col_mut(1).rows_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.col(2).nrows(), 2);
|
||||
assert_eq!(matrix.col(2).nnz(), 0);
|
||||
assert_eq!(matrix.col(2).row_indices(), &[]);
|
||||
assert_eq!(matrix.col(2).values(), &[]);
|
||||
assert_eq!(matrix.col_mut(2).nrows(), 2);
|
||||
assert_eq!(matrix.col_mut(2).nnz(), 0);
|
||||
assert_eq!(matrix.col_mut(2).row_indices(), &[]);
|
||||
assert_eq!(matrix.col_mut(2).values(), &[]);
|
||||
assert_eq!(matrix.col_mut(2).values_mut(), &[]);
|
||||
assert_eq!(
|
||||
matrix.col_mut(2).rows_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert!(matrix.get_col(3).is_none());
|
||||
assert!(matrix.get_col_mut(3).is_none());
|
||||
|
||||
let (offsets, indices, values) = matrix.disassemble();
|
||||
|
||||
assert_eq!(offsets, vec![0, 0, 0, 0]);
|
||||
assert_eq!(indices, vec![]);
|
||||
assert_eq!(values, vec![]);
|
||||
}
|
||||
|
||||
{
|
||||
// An arbitrary CSC matrix
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let mut matrix =
|
||||
CscMatrix::try_from_csc_data(6, 3, offsets.clone(), indices.clone(), values.clone())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(matrix.nrows(), 6);
|
||||
assert_eq!(matrix.ncols(), 3);
|
||||
assert_eq!(matrix.nnz(), 5);
|
||||
assert_eq!(matrix.col_offsets(), &[0, 2, 2, 5]);
|
||||
assert_eq!(matrix.row_indices(), &[0, 5, 1, 2, 3]);
|
||||
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
|
||||
|
||||
let expected_triplets = vec![(0, 0, 0), (5, 0, 1), (1, 2, 2), (2, 2, 3), (3, 2, 4)];
|
||||
assert_eq!(
|
||||
matrix
|
||||
.triplet_iter()
|
||||
.map(|(i, j, v)| (i, j, *v))
|
||||
.collect::<Vec<_>>(),
|
||||
expected_triplets
|
||||
);
|
||||
assert_eq!(
|
||||
matrix
|
||||
.triplet_iter_mut()
|
||||
.map(|(i, j, v)| (i, j, *v))
|
||||
.collect::<Vec<_>>(),
|
||||
expected_triplets
|
||||
);
|
||||
|
||||
assert_eq!(matrix.col(0).nrows(), 6);
|
||||
assert_eq!(matrix.col(0).nnz(), 2);
|
||||
assert_eq!(matrix.col(0).row_indices(), &[0, 5]);
|
||||
assert_eq!(matrix.col(0).values(), &[0, 1]);
|
||||
assert_eq!(matrix.col_mut(0).nrows(), 6);
|
||||
assert_eq!(matrix.col_mut(0).nnz(), 2);
|
||||
assert_eq!(matrix.col_mut(0).row_indices(), &[0, 5]);
|
||||
assert_eq!(matrix.col_mut(0).values(), &[0, 1]);
|
||||
assert_eq!(matrix.col_mut(0).values_mut(), &[0, 1]);
|
||||
assert_eq!(
|
||||
matrix.col_mut(0).rows_and_values_mut(),
|
||||
([0, 5].as_ref(), [0, 1].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.col(1).nrows(), 6);
|
||||
assert_eq!(matrix.col(1).nnz(), 0);
|
||||
assert_eq!(matrix.col(1).row_indices(), &[]);
|
||||
assert_eq!(matrix.col(1).values(), &[]);
|
||||
assert_eq!(matrix.col_mut(1).nrows(), 6);
|
||||
assert_eq!(matrix.col_mut(1).nnz(), 0);
|
||||
assert_eq!(matrix.col_mut(1).row_indices(), &[]);
|
||||
assert_eq!(matrix.col_mut(1).values(), &[]);
|
||||
assert_eq!(matrix.col_mut(1).values_mut(), &[]);
|
||||
assert_eq!(
|
||||
matrix.col_mut(1).rows_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.col(2).nrows(), 6);
|
||||
assert_eq!(matrix.col(2).nnz(), 3);
|
||||
assert_eq!(matrix.col(2).row_indices(), &[1, 2, 3]);
|
||||
assert_eq!(matrix.col(2).values(), &[2, 3, 4]);
|
||||
assert_eq!(matrix.col_mut(2).nrows(), 6);
|
||||
assert_eq!(matrix.col_mut(2).nnz(), 3);
|
||||
assert_eq!(matrix.col_mut(2).row_indices(), &[1, 2, 3]);
|
||||
assert_eq!(matrix.col_mut(2).values(), &[2, 3, 4]);
|
||||
assert_eq!(matrix.col_mut(2).values_mut(), &[2, 3, 4]);
|
||||
assert_eq!(
|
||||
matrix.col_mut(2).rows_and_values_mut(),
|
||||
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
|
||||
);
|
||||
|
||||
assert!(matrix.get_col(3).is_none());
|
||||
assert!(matrix.get_col_mut(3).is_none());
|
||||
|
||||
let (offsets2, indices2, values2) = matrix.disassemble();
|
||||
|
||||
assert_eq!(offsets2, offsets);
|
||||
assert_eq!(indices2, indices);
|
||||
assert_eq!(values2, values);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_matrix_try_from_invalid_csc_data() {
|
||||
{
|
||||
// Empty offset array (invalid length)
|
||||
let matrix = CscMatrix::try_from_csc_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Offset array invalid length for arbitrary data
|
||||
let offsets = vec![0, 3, 5];
|
||||
let indices = vec![0, 1, 2, 3, 5];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Invalid first entry in offsets array
|
||||
let offsets = vec![1, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Invalid last entry in offsets array
|
||||
let offsets = vec![0, 2, 2, 4];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Invalid length of offsets array
|
||||
let offsets = vec![0, 2, 2];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Nonmonotonic offsets
|
||||
let offsets = vec![0, 3, 2, 5];
|
||||
let indices = vec![0, 1, 2, 3, 4];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Nonmonotonic minor indices
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 2, 3, 1, 4];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Minor index out of bounds
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 6, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::IndexOutOfBounds
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Duplicate entry
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 5, 2, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::DuplicateEntry
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_disassemble_avoids_clone_when_owned() {
|
||||
// Test that disassemble avoids cloning the sparsity pattern when it holds the sole reference
|
||||
// to the pattern. We do so by checking that the pointer to the data is unchanged.
|
||||
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let offsets_ptr = offsets.as_ptr();
|
||||
let indices_ptr = indices.as_ptr();
|
||||
let values_ptr = values.as_ptr();
|
||||
let matrix = CscMatrix::try_from_csc_data(6, 3, offsets, indices, values).unwrap();
|
||||
|
||||
let (offsets, indices, values) = matrix.disassemble();
|
||||
assert_eq!(offsets.as_ptr(), offsets_ptr);
|
||||
assert_eq!(indices.as_ptr(), indices_ptr);
|
||||
assert_eq!(values.as_ptr(), values_ptr);
|
||||
}
|
||||
|
||||
// Rustfmt makes this test much harder to read by expanding some of the one-liners to 4-liners,
|
||||
// so for now we skip rustfmt...
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn csc_matrix_get_index_entry() {
|
||||
// Test .get_entry(_mut) and .index_entry(_mut) methods
|
||||
|
||||
#[rustfmt::skip]
|
||||
let dense = DMatrix::from_row_slice(2, 3, &[
|
||||
1, 0, 3,
|
||||
0, 5, 6
|
||||
]);
|
||||
let csc = CscMatrix::from(&dense);
|
||||
|
||||
assert_eq!(csc.get_entry(0, 0), Some(SparseEntry::NonZero(&1)));
|
||||
assert_eq!(csc.index_entry(0, 0), SparseEntry::NonZero(&1));
|
||||
assert_eq!(csc.get_entry(0, 1), Some(SparseEntry::Zero));
|
||||
assert_eq!(csc.index_entry(0, 1), SparseEntry::Zero);
|
||||
assert_eq!(csc.get_entry(0, 2), Some(SparseEntry::NonZero(&3)));
|
||||
assert_eq!(csc.index_entry(0, 2), SparseEntry::NonZero(&3));
|
||||
assert_eq!(csc.get_entry(1, 0), Some(SparseEntry::Zero));
|
||||
assert_eq!(csc.index_entry(1, 0), SparseEntry::Zero);
|
||||
assert_eq!(csc.get_entry(1, 1), Some(SparseEntry::NonZero(&5)));
|
||||
assert_eq!(csc.index_entry(1, 1), SparseEntry::NonZero(&5));
|
||||
assert_eq!(csc.get_entry(1, 2), Some(SparseEntry::NonZero(&6)));
|
||||
assert_eq!(csc.index_entry(1, 2), SparseEntry::NonZero(&6));
|
||||
|
||||
// Check some out of bounds with .get_entry
|
||||
assert_eq!(csc.get_entry(0, 3), None);
|
||||
assert_eq!(csc.get_entry(0, 4), None);
|
||||
assert_eq!(csc.get_entry(1, 3), None);
|
||||
assert_eq!(csc.get_entry(1, 4), None);
|
||||
assert_eq!(csc.get_entry(2, 0), None);
|
||||
assert_eq!(csc.get_entry(2, 1), None);
|
||||
assert_eq!(csc.get_entry(2, 2), None);
|
||||
assert_eq!(csc.get_entry(2, 3), None);
|
||||
assert_eq!(csc.get_entry(2, 4), None);
|
||||
|
||||
// Check that out of bounds with .index_entry panics
|
||||
assert_panics!(csc.index_entry(0, 3));
|
||||
assert_panics!(csc.index_entry(0, 4));
|
||||
assert_panics!(csc.index_entry(1, 3));
|
||||
assert_panics!(csc.index_entry(1, 4));
|
||||
assert_panics!(csc.index_entry(2, 0));
|
||||
assert_panics!(csc.index_entry(2, 1));
|
||||
assert_panics!(csc.index_entry(2, 2));
|
||||
assert_panics!(csc.index_entry(2, 3));
|
||||
assert_panics!(csc.index_entry(2, 4));
|
||||
|
||||
{
|
||||
// Check mutable versions of the above functions
|
||||
let mut csc = csc;
|
||||
|
||||
assert_eq!(csc.get_entry_mut(0, 0), Some(SparseEntryMut::NonZero(&mut 1)));
|
||||
assert_eq!(csc.index_entry_mut(0, 0), SparseEntryMut::NonZero(&mut 1));
|
||||
assert_eq!(csc.get_entry_mut(0, 1), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(csc.index_entry_mut(0, 1), SparseEntryMut::Zero);
|
||||
assert_eq!(csc.get_entry_mut(0, 2), Some(SparseEntryMut::NonZero(&mut 3)));
|
||||
assert_eq!(csc.index_entry_mut(0, 2), SparseEntryMut::NonZero(&mut 3));
|
||||
assert_eq!(csc.get_entry_mut(1, 0), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(csc.index_entry_mut(1, 0), SparseEntryMut::Zero);
|
||||
assert_eq!(csc.get_entry_mut(1, 1), Some(SparseEntryMut::NonZero(&mut 5)));
|
||||
assert_eq!(csc.index_entry_mut(1, 1), SparseEntryMut::NonZero(&mut 5));
|
||||
assert_eq!(csc.get_entry_mut(1, 2), Some(SparseEntryMut::NonZero(&mut 6)));
|
||||
assert_eq!(csc.index_entry_mut(1, 2), SparseEntryMut::NonZero(&mut 6));
|
||||
|
||||
// Check some out of bounds with .get_entry_mut
|
||||
assert_eq!(csc.get_entry_mut(0, 3), None);
|
||||
assert_eq!(csc.get_entry_mut(0, 4), None);
|
||||
assert_eq!(csc.get_entry_mut(1, 3), None);
|
||||
assert_eq!(csc.get_entry_mut(1, 4), None);
|
||||
assert_eq!(csc.get_entry_mut(2, 0), None);
|
||||
assert_eq!(csc.get_entry_mut(2, 1), None);
|
||||
assert_eq!(csc.get_entry_mut(2, 2), None);
|
||||
assert_eq!(csc.get_entry_mut(2, 3), None);
|
||||
assert_eq!(csc.get_entry_mut(2, 4), None);
|
||||
|
||||
// Check that out of bounds with .index_entry_mut panics
|
||||
// Note: the cloning is necessary because a mutable reference is not UnwindSafe
|
||||
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(0, 3); });
|
||||
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(0, 4); });
|
||||
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(1, 3); });
|
||||
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(1, 4); });
|
||||
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 0); });
|
||||
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 1); });
|
||||
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 2); });
|
||||
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 3); });
|
||||
assert_panics!({ let mut csc = csc.clone(); csc.index_entry_mut(2, 4); });
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_matrix_col_iter() {
|
||||
// Note: this is the transpose of the matrix used for the similar csr_matrix_row_iter test
|
||||
// (this way the actual tests are almost identical, due to the transposed relationship
|
||||
// between CSR and CSC)
|
||||
#[rustfmt::skip]
|
||||
let dense = DMatrix::from_row_slice(4, 3, &[
|
||||
0, 3, 0,
|
||||
1, 0, 4,
|
||||
2, 0, 0,
|
||||
0, 0, 5,
|
||||
]);
|
||||
let csc = CscMatrix::from(&dense);
|
||||
|
||||
// Immutable iterator
|
||||
{
|
||||
let mut col_iter = csc.col_iter();
|
||||
|
||||
{
|
||||
let col = col_iter.next().unwrap();
|
||||
assert_eq!(col.nrows(), 4);
|
||||
assert_eq!(col.nnz(), 2);
|
||||
assert_eq!(col.row_indices(), &[1, 2]);
|
||||
assert_eq!(col.values(), &[1, 2]);
|
||||
assert_eq!(col.get_entry(0), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(1), Some(SparseEntry::NonZero(&1)));
|
||||
assert_eq!(col.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
||||
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(4), None);
|
||||
}
|
||||
|
||||
{
|
||||
let col = col_iter.next().unwrap();
|
||||
assert_eq!(col.nrows(), 4);
|
||||
assert_eq!(col.nnz(), 1);
|
||||
assert_eq!(col.row_indices(), &[0]);
|
||||
assert_eq!(col.values(), &[3]);
|
||||
assert_eq!(col.get_entry(0), Some(SparseEntry::NonZero(&3)));
|
||||
assert_eq!(col.get_entry(1), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(4), None);
|
||||
}
|
||||
|
||||
{
|
||||
let col = col_iter.next().unwrap();
|
||||
assert_eq!(col.nrows(), 4);
|
||||
assert_eq!(col.nnz(), 2);
|
||||
assert_eq!(col.row_indices(), &[1, 3]);
|
||||
assert_eq!(col.values(), &[4, 5]);
|
||||
assert_eq!(col.get_entry(0), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(1), Some(SparseEntry::NonZero(&4)));
|
||||
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
||||
assert_eq!(col.get_entry(4), None);
|
||||
}
|
||||
|
||||
assert!(col_iter.next().is_none());
|
||||
}
|
||||
|
||||
// Mutable iterator
|
||||
{
|
||||
let mut csc = csc;
|
||||
let mut col_iter = csc.col_iter_mut();
|
||||
|
||||
{
|
||||
let mut col = col_iter.next().unwrap();
|
||||
assert_eq!(col.nrows(), 4);
|
||||
assert_eq!(col.nnz(), 2);
|
||||
assert_eq!(col.row_indices(), &[1, 2]);
|
||||
assert_eq!(col.values(), &[1, 2]);
|
||||
assert_eq!(col.get_entry(0), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(1), Some(SparseEntry::NonZero(&1)));
|
||||
assert_eq!(col.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
||||
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(4), None);
|
||||
|
||||
assert_eq!(col.values_mut(), &mut [1, 2]);
|
||||
assert_eq!(
|
||||
col.rows_and_values_mut(),
|
||||
([1, 2].as_ref(), [1, 2].as_mut())
|
||||
);
|
||||
assert_eq!(col.get_entry_mut(0), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(col.get_entry_mut(1), Some(SparseEntryMut::NonZero(&mut 1)));
|
||||
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
|
||||
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(col.get_entry_mut(4), None);
|
||||
}
|
||||
|
||||
{
|
||||
let mut col = col_iter.next().unwrap();
|
||||
assert_eq!(col.nrows(), 4);
|
||||
assert_eq!(col.nnz(), 1);
|
||||
assert_eq!(col.row_indices(), &[0]);
|
||||
assert_eq!(col.values(), &[3]);
|
||||
assert_eq!(col.get_entry(0), Some(SparseEntry::NonZero(&3)));
|
||||
assert_eq!(col.get_entry(1), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(3), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(4), None);
|
||||
|
||||
assert_eq!(col.values_mut(), &mut [3]);
|
||||
assert_eq!(col.rows_and_values_mut(), ([0].as_ref(), [3].as_mut()));
|
||||
assert_eq!(col.get_entry_mut(0), Some(SparseEntryMut::NonZero(&mut 3)));
|
||||
assert_eq!(col.get_entry_mut(1), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(col.get_entry_mut(4), None);
|
||||
}
|
||||
|
||||
{
|
||||
let mut col = col_iter.next().unwrap();
|
||||
assert_eq!(col.nrows(), 4);
|
||||
assert_eq!(col.nnz(), 2);
|
||||
assert_eq!(col.row_indices(), &[1, 3]);
|
||||
assert_eq!(col.values(), &[4, 5]);
|
||||
assert_eq!(col.get_entry(0), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(1), Some(SparseEntry::NonZero(&4)));
|
||||
assert_eq!(col.get_entry(2), Some(SparseEntry::Zero));
|
||||
assert_eq!(col.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
||||
assert_eq!(col.get_entry(4), None);
|
||||
|
||||
assert_eq!(col.values_mut(), &mut [4, 5]);
|
||||
assert_eq!(
|
||||
col.rows_and_values_mut(),
|
||||
([1, 3].as_ref(), [4, 5].as_mut())
|
||||
);
|
||||
assert_eq!(col.get_entry_mut(0), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(col.get_entry_mut(1), Some(SparseEntryMut::NonZero(&mut 4)));
|
||||
assert_eq!(col.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(col.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
|
||||
assert_eq!(col.get_entry_mut(4), None);
|
||||
}
|
||||
|
||||
assert!(col_iter.next().is_none());
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn csc_double_transpose_is_identity(csc in csc_strategy()) {
|
||||
prop_assert_eq!(csc.transpose().transpose(), csc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_transpose_agrees_with_dense(csc in csc_strategy()) {
|
||||
let dense_transpose = DMatrix::from(&csc).transpose();
|
||||
let csc_transpose = csc.transpose();
|
||||
prop_assert_eq!(dense_transpose, DMatrix::from(&csc_transpose));
|
||||
prop_assert_eq!(csc.nnz(), csc_transpose.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_filter(
|
||||
(csc, triplet_subset)
|
||||
in csc_strategy()
|
||||
.prop_flat_map(|matrix| {
|
||||
let triplets: Vec<_> = matrix.triplet_iter().cloned_values().collect();
|
||||
let subset = subsequence(triplets, 0 ..= matrix.nnz())
|
||||
.prop_map(|triplet_subset| {
|
||||
let set: HashSet<_> = triplet_subset.into_iter().collect();
|
||||
set
|
||||
});
|
||||
(Just(matrix), subset)
|
||||
}))
|
||||
{
|
||||
// We generate a CscMatrix and a HashSet corresponding to a subset of the (i, j, v)
|
||||
// values in the matrix, which we use for filtering the matrix entries.
|
||||
// The resulting triplets in the filtered matrix must then be exactly equal to
|
||||
// the subset.
|
||||
let filtered = csc.filter(|i, j, v| triplet_subset.contains(&(i, j, *v)));
|
||||
let filtered_triplets: HashSet<_> = filtered
|
||||
.triplet_iter()
|
||||
.cloned_values()
|
||||
.collect();
|
||||
|
||||
prop_assert_eq!(filtered_triplets, triplet_subset);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_lower_triangle_agrees_with_dense(csc in csc_strategy()) {
|
||||
let csc_lower_triangle = csc.lower_triangle();
|
||||
prop_assert_eq!(DMatrix::from(&csc_lower_triangle), DMatrix::from(&csc).lower_triangle());
|
||||
prop_assert!(csc_lower_triangle.nnz() <= csc.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_upper_triangle_agrees_with_dense(csc in csc_strategy()) {
|
||||
let csc_upper_triangle = csc.upper_triangle();
|
||||
prop_assert_eq!(DMatrix::from(&csc_upper_triangle), DMatrix::from(&csc).upper_triangle());
|
||||
prop_assert!(csc_upper_triangle.nnz() <= csc.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_diagonal_as_csc(csc in csc_strategy()) {
|
||||
let d = csc.diagonal_as_csc();
|
||||
let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect();
|
||||
let csc_diagonal_entries: HashSet<_> = csc
|
||||
.triplet_iter()
|
||||
.cloned_values()
|
||||
.filter(|&(i, j, _)| i == j)
|
||||
.collect();
|
||||
|
||||
prop_assert_eq!(d_entries, csc_diagonal_entries);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csc_identity(n in 0 ..= 6usize) {
|
||||
let csc = CscMatrix::<i32>::identity(n);
|
||||
prop_assert_eq!(csc.nnz(), n);
|
||||
prop_assert_eq!(DMatrix::from(&csc), DMatrix::identity(n, n));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,601 @@
|
|||
use nalgebra::DMatrix;
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use nalgebra_sparse::{SparseEntry, SparseEntryMut, SparseFormatErrorKind};
|
||||
|
||||
use proptest::prelude::*;
|
||||
use proptest::sample::subsequence;
|
||||
|
||||
use crate::assert_panics;
|
||||
use crate::common::csr_strategy;
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[test]
|
||||
fn csr_matrix_valid_data() {
|
||||
// Construct matrix from valid data and check that selected methods return results
|
||||
// that agree with expectations.
|
||||
|
||||
{
|
||||
// A CSR matrix with zero explicitly stored entries
|
||||
let offsets = vec![0, 0, 0, 0];
|
||||
let indices = vec![];
|
||||
let values = Vec::<i32>::new();
|
||||
let mut matrix = CsrMatrix::try_from_csr_data(3, 2, offsets, indices, values).unwrap();
|
||||
|
||||
assert_eq!(matrix, CsrMatrix::zeros(3, 2));
|
||||
|
||||
assert_eq!(matrix.nrows(), 3);
|
||||
assert_eq!(matrix.ncols(), 2);
|
||||
assert_eq!(matrix.nnz(), 0);
|
||||
assert_eq!(matrix.row_offsets(), &[0, 0, 0, 0]);
|
||||
assert_eq!(matrix.col_indices(), &[]);
|
||||
assert_eq!(matrix.values(), &[]);
|
||||
|
||||
assert!(matrix.triplet_iter().next().is_none());
|
||||
assert!(matrix.triplet_iter_mut().next().is_none());
|
||||
|
||||
assert_eq!(matrix.row(0).ncols(), 2);
|
||||
assert_eq!(matrix.row(0).nnz(), 0);
|
||||
assert_eq!(matrix.row(0).col_indices(), &[]);
|
||||
assert_eq!(matrix.row(0).values(), &[]);
|
||||
assert_eq!(matrix.row_mut(0).ncols(), 2);
|
||||
assert_eq!(matrix.row_mut(0).nnz(), 0);
|
||||
assert_eq!(matrix.row_mut(0).col_indices(), &[]);
|
||||
assert_eq!(matrix.row_mut(0).values(), &[]);
|
||||
assert_eq!(matrix.row_mut(0).values_mut(), &[]);
|
||||
assert_eq!(
|
||||
matrix.row_mut(0).cols_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.row(1).ncols(), 2);
|
||||
assert_eq!(matrix.row(1).nnz(), 0);
|
||||
assert_eq!(matrix.row(1).col_indices(), &[]);
|
||||
assert_eq!(matrix.row(1).values(), &[]);
|
||||
assert_eq!(matrix.row_mut(1).ncols(), 2);
|
||||
assert_eq!(matrix.row_mut(1).nnz(), 0);
|
||||
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
|
||||
assert_eq!(matrix.row_mut(1).values(), &[]);
|
||||
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
|
||||
assert_eq!(
|
||||
matrix.row_mut(1).cols_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.row(2).ncols(), 2);
|
||||
assert_eq!(matrix.row(2).nnz(), 0);
|
||||
assert_eq!(matrix.row(2).col_indices(), &[]);
|
||||
assert_eq!(matrix.row(2).values(), &[]);
|
||||
assert_eq!(matrix.row_mut(2).ncols(), 2);
|
||||
assert_eq!(matrix.row_mut(2).nnz(), 0);
|
||||
assert_eq!(matrix.row_mut(2).col_indices(), &[]);
|
||||
assert_eq!(matrix.row_mut(2).values(), &[]);
|
||||
assert_eq!(matrix.row_mut(2).values_mut(), &[]);
|
||||
assert_eq!(
|
||||
matrix.row_mut(2).cols_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert!(matrix.get_row(3).is_none());
|
||||
assert!(matrix.get_row_mut(3).is_none());
|
||||
|
||||
let (offsets, indices, values) = matrix.disassemble();
|
||||
|
||||
assert_eq!(offsets, vec![0, 0, 0, 0]);
|
||||
assert_eq!(indices, vec![]);
|
||||
assert_eq!(values, vec![]);
|
||||
}
|
||||
|
||||
{
|
||||
// An arbitrary CSR matrix
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let mut matrix =
|
||||
CsrMatrix::try_from_csr_data(3, 6, offsets.clone(), indices.clone(), values.clone())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(matrix.nrows(), 3);
|
||||
assert_eq!(matrix.ncols(), 6);
|
||||
assert_eq!(matrix.nnz(), 5);
|
||||
assert_eq!(matrix.row_offsets(), &[0, 2, 2, 5]);
|
||||
assert_eq!(matrix.col_indices(), &[0, 5, 1, 2, 3]);
|
||||
assert_eq!(matrix.values(), &[0, 1, 2, 3, 4]);
|
||||
|
||||
let expected_triplets = vec![(0, 0, 0), (0, 5, 1), (2, 1, 2), (2, 2, 3), (2, 3, 4)];
|
||||
assert_eq!(
|
||||
matrix
|
||||
.triplet_iter()
|
||||
.map(|(i, j, v)| (i, j, *v))
|
||||
.collect::<Vec<_>>(),
|
||||
expected_triplets
|
||||
);
|
||||
assert_eq!(
|
||||
matrix
|
||||
.triplet_iter_mut()
|
||||
.map(|(i, j, v)| (i, j, *v))
|
||||
.collect::<Vec<_>>(),
|
||||
expected_triplets
|
||||
);
|
||||
|
||||
assert_eq!(matrix.row(0).ncols(), 6);
|
||||
assert_eq!(matrix.row(0).nnz(), 2);
|
||||
assert_eq!(matrix.row(0).col_indices(), &[0, 5]);
|
||||
assert_eq!(matrix.row(0).values(), &[0, 1]);
|
||||
assert_eq!(matrix.row_mut(0).ncols(), 6);
|
||||
assert_eq!(matrix.row_mut(0).nnz(), 2);
|
||||
assert_eq!(matrix.row_mut(0).col_indices(), &[0, 5]);
|
||||
assert_eq!(matrix.row_mut(0).values(), &[0, 1]);
|
||||
assert_eq!(matrix.row_mut(0).values_mut(), &[0, 1]);
|
||||
assert_eq!(
|
||||
matrix.row_mut(0).cols_and_values_mut(),
|
||||
([0, 5].as_ref(), [0, 1].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.row(1).ncols(), 6);
|
||||
assert_eq!(matrix.row(1).nnz(), 0);
|
||||
assert_eq!(matrix.row(1).col_indices(), &[]);
|
||||
assert_eq!(matrix.row(1).values(), &[]);
|
||||
assert_eq!(matrix.row_mut(1).ncols(), 6);
|
||||
assert_eq!(matrix.row_mut(1).nnz(), 0);
|
||||
assert_eq!(matrix.row_mut(1).col_indices(), &[]);
|
||||
assert_eq!(matrix.row_mut(1).values(), &[]);
|
||||
assert_eq!(matrix.row_mut(1).values_mut(), &[]);
|
||||
assert_eq!(
|
||||
matrix.row_mut(1).cols_and_values_mut(),
|
||||
([].as_ref(), [].as_mut())
|
||||
);
|
||||
|
||||
assert_eq!(matrix.row(2).ncols(), 6);
|
||||
assert_eq!(matrix.row(2).nnz(), 3);
|
||||
assert_eq!(matrix.row(2).col_indices(), &[1, 2, 3]);
|
||||
assert_eq!(matrix.row(2).values(), &[2, 3, 4]);
|
||||
assert_eq!(matrix.row_mut(2).ncols(), 6);
|
||||
assert_eq!(matrix.row_mut(2).nnz(), 3);
|
||||
assert_eq!(matrix.row_mut(2).col_indices(), &[1, 2, 3]);
|
||||
assert_eq!(matrix.row_mut(2).values(), &[2, 3, 4]);
|
||||
assert_eq!(matrix.row_mut(2).values_mut(), &[2, 3, 4]);
|
||||
assert_eq!(
|
||||
matrix.row_mut(2).cols_and_values_mut(),
|
||||
([1, 2, 3].as_ref(), [2, 3, 4].as_mut())
|
||||
);
|
||||
|
||||
assert!(matrix.get_row(3).is_none());
|
||||
assert!(matrix.get_row_mut(3).is_none());
|
||||
|
||||
let (offsets2, indices2, values2) = matrix.disassemble();
|
||||
|
||||
assert_eq!(offsets2, offsets);
|
||||
assert_eq!(indices2, indices);
|
||||
assert_eq!(values2, values);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_matrix_try_from_invalid_csr_data() {
|
||||
{
|
||||
// Empty offset array (invalid length)
|
||||
let matrix = CsrMatrix::try_from_csr_data(0, 0, Vec::new(), Vec::new(), Vec::<u32>::new());
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Offset array invalid length for arbitrary data
|
||||
let offsets = vec![0, 3, 5];
|
||||
let indices = vec![0, 1, 2, 3, 5];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Invalid first entry in offsets array
|
||||
let offsets = vec![1, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Invalid last entry in offsets array
|
||||
let offsets = vec![0, 2, 2, 4];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Invalid length of offsets array
|
||||
let offsets = vec![0, 2, 2];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Nonmonotonic offsets
|
||||
let offsets = vec![0, 3, 2, 5];
|
||||
let indices = vec![0, 1, 2, 3, 4];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Nonmonotonic minor indices
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 2, 3, 1, 4];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::InvalidStructure
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Minor index out of bounds
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 6, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::IndexOutOfBounds
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Duplicate entry
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 5, 2, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values);
|
||||
assert_eq!(
|
||||
matrix.unwrap_err().kind(),
|
||||
&SparseFormatErrorKind::DuplicateEntry
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_disassemble_avoids_clone_when_owned() {
|
||||
// Test that disassemble avoids cloning the sparsity pattern when it holds the sole reference
|
||||
// to the pattern. We do so by checking that the pointer to the data is unchanged.
|
||||
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let values = vec![0, 1, 2, 3, 4];
|
||||
let offsets_ptr = offsets.as_ptr();
|
||||
let indices_ptr = indices.as_ptr();
|
||||
let values_ptr = values.as_ptr();
|
||||
let matrix = CsrMatrix::try_from_csr_data(3, 6, offsets, indices, values).unwrap();
|
||||
|
||||
let (offsets, indices, values) = matrix.disassemble();
|
||||
assert_eq!(offsets.as_ptr(), offsets_ptr);
|
||||
assert_eq!(indices.as_ptr(), indices_ptr);
|
||||
assert_eq!(values.as_ptr(), values_ptr);
|
||||
}
|
||||
|
||||
// Rustfmt makes this test much harder to read by expanding some of the one-liners to 4-liners,
|
||||
// so for now we skip rustfmt...
|
||||
#[rustfmt::skip]
|
||||
#[test]
|
||||
fn csr_matrix_get_index_entry() {
|
||||
// Test .get_entry(_mut) and .index_entry(_mut) methods
|
||||
|
||||
#[rustfmt::skip]
|
||||
let dense = DMatrix::from_row_slice(2, 3, &[
|
||||
1, 0, 3,
|
||||
0, 5, 6
|
||||
]);
|
||||
let csr = CsrMatrix::from(&dense);
|
||||
|
||||
assert_eq!(csr.get_entry(0, 0), Some(SparseEntry::NonZero(&1)));
|
||||
assert_eq!(csr.index_entry(0, 0), SparseEntry::NonZero(&1));
|
||||
assert_eq!(csr.get_entry(0, 1), Some(SparseEntry::Zero));
|
||||
assert_eq!(csr.index_entry(0, 1), SparseEntry::Zero);
|
||||
assert_eq!(csr.get_entry(0, 2), Some(SparseEntry::NonZero(&3)));
|
||||
assert_eq!(csr.index_entry(0, 2), SparseEntry::NonZero(&3));
|
||||
assert_eq!(csr.get_entry(1, 0), Some(SparseEntry::Zero));
|
||||
assert_eq!(csr.index_entry(1, 0), SparseEntry::Zero);
|
||||
assert_eq!(csr.get_entry(1, 1), Some(SparseEntry::NonZero(&5)));
|
||||
assert_eq!(csr.index_entry(1, 1), SparseEntry::NonZero(&5));
|
||||
assert_eq!(csr.get_entry(1, 2), Some(SparseEntry::NonZero(&6)));
|
||||
assert_eq!(csr.index_entry(1, 2), SparseEntry::NonZero(&6));
|
||||
|
||||
// Check some out of bounds with .get_entry
|
||||
assert_eq!(csr.get_entry(0, 3), None);
|
||||
assert_eq!(csr.get_entry(0, 4), None);
|
||||
assert_eq!(csr.get_entry(1, 3), None);
|
||||
assert_eq!(csr.get_entry(1, 4), None);
|
||||
assert_eq!(csr.get_entry(2, 0), None);
|
||||
assert_eq!(csr.get_entry(2, 1), None);
|
||||
assert_eq!(csr.get_entry(2, 2), None);
|
||||
assert_eq!(csr.get_entry(2, 3), None);
|
||||
assert_eq!(csr.get_entry(2, 4), None);
|
||||
|
||||
// Check that out of bounds with .index_entry panics
|
||||
assert_panics!(csr.index_entry(0, 3));
|
||||
assert_panics!(csr.index_entry(0, 4));
|
||||
assert_panics!(csr.index_entry(1, 3));
|
||||
assert_panics!(csr.index_entry(1, 4));
|
||||
assert_panics!(csr.index_entry(2, 0));
|
||||
assert_panics!(csr.index_entry(2, 1));
|
||||
assert_panics!(csr.index_entry(2, 2));
|
||||
assert_panics!(csr.index_entry(2, 3));
|
||||
assert_panics!(csr.index_entry(2, 4));
|
||||
|
||||
{
|
||||
// Check mutable versions of the above functions
|
||||
let mut csr = csr;
|
||||
|
||||
assert_eq!(csr.get_entry_mut(0, 0), Some(SparseEntryMut::NonZero(&mut 1)));
|
||||
assert_eq!(csr.index_entry_mut(0, 0), SparseEntryMut::NonZero(&mut 1));
|
||||
assert_eq!(csr.get_entry_mut(0, 1), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(csr.index_entry_mut(0, 1), SparseEntryMut::Zero);
|
||||
assert_eq!(csr.get_entry_mut(0, 2), Some(SparseEntryMut::NonZero(&mut 3)));
|
||||
assert_eq!(csr.index_entry_mut(0, 2), SparseEntryMut::NonZero(&mut 3));
|
||||
assert_eq!(csr.get_entry_mut(1, 0), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(csr.index_entry_mut(1, 0), SparseEntryMut::Zero);
|
||||
assert_eq!(csr.get_entry_mut(1, 1), Some(SparseEntryMut::NonZero(&mut 5)));
|
||||
assert_eq!(csr.index_entry_mut(1, 1), SparseEntryMut::NonZero(&mut 5));
|
||||
assert_eq!(csr.get_entry_mut(1, 2), Some(SparseEntryMut::NonZero(&mut 6)));
|
||||
assert_eq!(csr.index_entry_mut(1, 2), SparseEntryMut::NonZero(&mut 6));
|
||||
|
||||
// Check some out of bounds with .get_entry_mut
|
||||
assert_eq!(csr.get_entry_mut(0, 3), None);
|
||||
assert_eq!(csr.get_entry_mut(0, 4), None);
|
||||
assert_eq!(csr.get_entry_mut(1, 3), None);
|
||||
assert_eq!(csr.get_entry_mut(1, 4), None);
|
||||
assert_eq!(csr.get_entry_mut(2, 0), None);
|
||||
assert_eq!(csr.get_entry_mut(2, 1), None);
|
||||
assert_eq!(csr.get_entry_mut(2, 2), None);
|
||||
assert_eq!(csr.get_entry_mut(2, 3), None);
|
||||
assert_eq!(csr.get_entry_mut(2, 4), None);
|
||||
|
||||
// Check that out of bounds with .index_entry_mut panics
|
||||
// Note: the cloning is necessary because a mutable reference is not UnwindSafe
|
||||
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(0, 3); });
|
||||
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(0, 4); });
|
||||
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(1, 3); });
|
||||
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(1, 4); });
|
||||
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 0); });
|
||||
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 1); });
|
||||
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 2); });
|
||||
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 3); });
|
||||
assert_panics!({ let mut csr = csr.clone(); csr.index_entry_mut(2, 4); });
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_matrix_row_iter() {
|
||||
#[rustfmt::skip]
|
||||
let dense = DMatrix::from_row_slice(3, 4, &[
|
||||
0, 1, 2, 0,
|
||||
3, 0, 0, 0,
|
||||
0, 4, 0, 5
|
||||
]);
|
||||
let csr = CsrMatrix::from(&dense);
|
||||
|
||||
// Immutable iterator
|
||||
{
|
||||
let mut row_iter = csr.row_iter();
|
||||
|
||||
{
|
||||
let row = row_iter.next().unwrap();
|
||||
assert_eq!(row.ncols(), 4);
|
||||
assert_eq!(row.nnz(), 2);
|
||||
assert_eq!(row.col_indices(), &[1, 2]);
|
||||
assert_eq!(row.values(), &[1, 2]);
|
||||
assert_eq!(row.get_entry(0), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(1), Some(SparseEntry::NonZero(&1)));
|
||||
assert_eq!(row.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
||||
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(4), None);
|
||||
}
|
||||
|
||||
{
|
||||
let row = row_iter.next().unwrap();
|
||||
assert_eq!(row.ncols(), 4);
|
||||
assert_eq!(row.nnz(), 1);
|
||||
assert_eq!(row.col_indices(), &[0]);
|
||||
assert_eq!(row.values(), &[3]);
|
||||
assert_eq!(row.get_entry(0), Some(SparseEntry::NonZero(&3)));
|
||||
assert_eq!(row.get_entry(1), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(4), None);
|
||||
}
|
||||
|
||||
{
|
||||
let row = row_iter.next().unwrap();
|
||||
assert_eq!(row.ncols(), 4);
|
||||
assert_eq!(row.nnz(), 2);
|
||||
assert_eq!(row.col_indices(), &[1, 3]);
|
||||
assert_eq!(row.values(), &[4, 5]);
|
||||
assert_eq!(row.get_entry(0), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(1), Some(SparseEntry::NonZero(&4)));
|
||||
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
||||
assert_eq!(row.get_entry(4), None);
|
||||
}
|
||||
|
||||
assert!(row_iter.next().is_none());
|
||||
}
|
||||
|
||||
// Mutable iterator
|
||||
{
|
||||
let mut csr = csr;
|
||||
let mut row_iter = csr.row_iter_mut();
|
||||
|
||||
{
|
||||
let mut row = row_iter.next().unwrap();
|
||||
assert_eq!(row.ncols(), 4);
|
||||
assert_eq!(row.nnz(), 2);
|
||||
assert_eq!(row.col_indices(), &[1, 2]);
|
||||
assert_eq!(row.values(), &[1, 2]);
|
||||
assert_eq!(row.get_entry(0), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(1), Some(SparseEntry::NonZero(&1)));
|
||||
assert_eq!(row.get_entry(2), Some(SparseEntry::NonZero(&2)));
|
||||
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(4), None);
|
||||
|
||||
assert_eq!(row.values_mut(), &mut [1, 2]);
|
||||
assert_eq!(
|
||||
row.cols_and_values_mut(),
|
||||
([1, 2].as_ref(), [1, 2].as_mut())
|
||||
);
|
||||
assert_eq!(row.get_entry_mut(0), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(row.get_entry_mut(1), Some(SparseEntryMut::NonZero(&mut 1)));
|
||||
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::NonZero(&mut 2)));
|
||||
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(row.get_entry_mut(4), None);
|
||||
}
|
||||
|
||||
{
|
||||
let mut row = row_iter.next().unwrap();
|
||||
assert_eq!(row.ncols(), 4);
|
||||
assert_eq!(row.nnz(), 1);
|
||||
assert_eq!(row.col_indices(), &[0]);
|
||||
assert_eq!(row.values(), &[3]);
|
||||
assert_eq!(row.get_entry(0), Some(SparseEntry::NonZero(&3)));
|
||||
assert_eq!(row.get_entry(1), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(3), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(4), None);
|
||||
|
||||
assert_eq!(row.values_mut(), &mut [3]);
|
||||
assert_eq!(row.cols_and_values_mut(), ([0].as_ref(), [3].as_mut()));
|
||||
assert_eq!(row.get_entry_mut(0), Some(SparseEntryMut::NonZero(&mut 3)));
|
||||
assert_eq!(row.get_entry_mut(1), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(row.get_entry_mut(4), None);
|
||||
}
|
||||
|
||||
{
|
||||
let mut row = row_iter.next().unwrap();
|
||||
assert_eq!(row.ncols(), 4);
|
||||
assert_eq!(row.nnz(), 2);
|
||||
assert_eq!(row.col_indices(), &[1, 3]);
|
||||
assert_eq!(row.values(), &[4, 5]);
|
||||
assert_eq!(row.get_entry(0), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(1), Some(SparseEntry::NonZero(&4)));
|
||||
assert_eq!(row.get_entry(2), Some(SparseEntry::Zero));
|
||||
assert_eq!(row.get_entry(3), Some(SparseEntry::NonZero(&5)));
|
||||
assert_eq!(row.get_entry(4), None);
|
||||
|
||||
assert_eq!(row.values_mut(), &mut [4, 5]);
|
||||
assert_eq!(
|
||||
row.cols_and_values_mut(),
|
||||
([1, 3].as_ref(), [4, 5].as_mut())
|
||||
);
|
||||
assert_eq!(row.get_entry_mut(0), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(row.get_entry_mut(1), Some(SparseEntryMut::NonZero(&mut 4)));
|
||||
assert_eq!(row.get_entry_mut(2), Some(SparseEntryMut::Zero));
|
||||
assert_eq!(row.get_entry_mut(3), Some(SparseEntryMut::NonZero(&mut 5)));
|
||||
assert_eq!(row.get_entry_mut(4), None);
|
||||
}
|
||||
|
||||
assert!(row_iter.next().is_none());
|
||||
}
|
||||
}
|
||||
|
||||
proptest! {
|
||||
#[test]
|
||||
fn csr_double_transpose_is_identity(csr in csr_strategy()) {
|
||||
prop_assert_eq!(csr.transpose().transpose(), csr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_transpose_agrees_with_dense(csr in csr_strategy()) {
|
||||
let dense_transpose = DMatrix::from(&csr).transpose();
|
||||
let csr_transpose = csr.transpose();
|
||||
prop_assert_eq!(dense_transpose, DMatrix::from(&csr_transpose));
|
||||
prop_assert_eq!(csr.nnz(), csr_transpose.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_filter(
|
||||
(csr, triplet_subset)
|
||||
in csr_strategy()
|
||||
.prop_flat_map(|matrix| {
|
||||
let triplets: Vec<_> = matrix.triplet_iter().cloned_values().collect();
|
||||
let subset = subsequence(triplets, 0 ..= matrix.nnz())
|
||||
.prop_map(|triplet_subset| {
|
||||
let set: HashSet<_> = triplet_subset.into_iter().collect();
|
||||
set
|
||||
});
|
||||
(Just(matrix), subset)
|
||||
}))
|
||||
{
|
||||
// We generate a CsrMatrix and a HashSet corresponding to a subset of the (i, j, v)
|
||||
// values in the matrix, which we use for filtering the matrix entries.
|
||||
// The resulting triplets in the filtered matrix must then be exactly equal to
|
||||
// the subset.
|
||||
let filtered = csr.filter(|i, j, v| triplet_subset.contains(&(i, j, *v)));
|
||||
let filtered_triplets: HashSet<_> = filtered
|
||||
.triplet_iter()
|
||||
.cloned_values()
|
||||
.collect();
|
||||
|
||||
prop_assert_eq!(filtered_triplets, triplet_subset);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_lower_triangle_agrees_with_dense(csr in csr_strategy()) {
|
||||
let csr_lower_triangle = csr.lower_triangle();
|
||||
prop_assert_eq!(DMatrix::from(&csr_lower_triangle), DMatrix::from(&csr).lower_triangle());
|
||||
prop_assert!(csr_lower_triangle.nnz() <= csr.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_upper_triangle_agrees_with_dense(csr in csr_strategy()) {
|
||||
let csr_upper_triangle = csr.upper_triangle();
|
||||
prop_assert_eq!(DMatrix::from(&csr_upper_triangle), DMatrix::from(&csr).upper_triangle());
|
||||
prop_assert!(csr_upper_triangle.nnz() <= csr.nnz());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_diagonal_as_csr(csr in csr_strategy()) {
|
||||
let d = csr.diagonal_as_csr();
|
||||
let d_entries: HashSet<_> = d.triplet_iter().cloned_values().collect();
|
||||
let csr_diagonal_entries: HashSet<_> = csr
|
||||
.triplet_iter()
|
||||
.cloned_values()
|
||||
.filter(|&(i, j, _)| i == j)
|
||||
.collect();
|
||||
|
||||
prop_assert_eq!(d_entries, csr_diagonal_entries);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn csr_identity(n in 0 ..= 6usize) {
|
||||
let csr = CsrMatrix::<i32>::identity(n);
|
||||
prop_assert_eq!(csr.nnz(), n);
|
||||
prop_assert_eq!(DMatrix::from(&csr), DMatrix::identity(n, n));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,8 @@
|
|||
mod cholesky;
|
||||
mod convert_serial;
|
||||
mod coo;
|
||||
mod csc;
|
||||
mod csr;
|
||||
mod ops;
|
||||
mod pattern;
|
||||
mod proptest;
|
|
@ -0,0 +1,14 @@
|
|||
# Seeds for failure cases proptest has generated in the past. It is
|
||||
# automatically read and these particular cases re-run before any
|
||||
# novel cases are generated.
|
||||
#
|
||||
# It is recommended to check this file in to source control so that
|
||||
# everyone who runs the test benefits from these saved cases.
|
||||
cc 6748ea4ac9523fcc4dd8327b27c6818f8df10eb2042774f59a6e3fa3205dbcbd # shrinks to (beta, alpha, (c, a, b)) = (0, -1, (Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 1, 5, -4, 2], nrows: Dynamic { value: 3 }, ncols: Dynamic { value: 3 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 2, 2, 2], minor_indices: [0, 1], minor_dim: 5 }, values: [-5, -2] }, Matrix { data: VecStorage { data: [4, -2, -3, -3, -5, 3, 5, 1, -4, -4, 3, 5, 5, 5, -3], nrows: Dynamic { value: 5 }, ncols: Dynamic { value: 3 } } }))
|
||||
cc dcf67ab7b8febf109cfa58ee0f082b9f7c23d6ad0df2e28dc99984deeb6b113a # shrinks to (beta, alpha, (c, a, b)) = (0, 0, (Matrix { data: VecStorage { data: [0, -1], nrows: Dynamic { value: 1 }, ncols: Dynamic { value: 2 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0], minor_indices: [], minor_dim: 4 }, values: [] }, Matrix { data: VecStorage { data: [3, 1, 1, 0, 0, 3, -5, -3], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 2 } } }))
|
||||
cc dbaef9886eaad28be7cd48326b857f039d695bc0b19e9ada3304e812e984d2c3 # shrinks to (beta, alpha, (c, a, b)) = (0, -1, (Matrix { data: VecStorage { data: [1], nrows: Dynamic { value: 1 }, ncols: Dynamic { value: 1 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0], minor_indices: [], minor_dim: 0 }, values: [] }, Matrix { data: VecStorage { data: [], nrows: Dynamic { value: 0 }, ncols: Dynamic { value: 1 } } }))
|
||||
cc 99e312beb498ffa79194f41501ea312dce1911878eba131282904ac97205aaa9 # shrinks to SpmmCsrDenseArgs { c, beta, alpha, trans_a, a, trans_b, b } = SpmmCsrDenseArgs { c: Matrix { data: VecStorage { data: [-1, 4, -1, -4, 2, 1, 4, -2, 1, 3, -2, 5], nrows: Dynamic { value: 2 }, ncols: Dynamic { value: 6 } } }, beta: 0, alpha: 0, trans_a: Transpose, a: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 1, 1, 1, 1, 1, 1], minor_indices: [0], minor_dim: 2 }, values: [0] }, trans_b: Transpose, b: Matrix { data: VecStorage { data: [-1, 1, 0, -5, 4, -5, 2, 2, 4, -4, -3, -1, 1, -1, 0, 1, -3, 4, -5, 0, 1, -5, 0, 1, 1, -3, 5, 3, 5, -3, -5, 3, -1, -4, -4, -3], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 6 } } } }
|
||||
cc bf74259df2db6eda24eb42098e57ea1c604bb67d6d0023fa308c321027b53a43 # shrinks to (alpha, beta, c, a, b, trans_a, trans_b) = (0, 0, Matrix { data: VecStorage { data: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 5 } } }, CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 3, 6, 9, 12], minor_indices: [0, 1, 3, 1, 2, 3, 0, 1, 2, 1, 2, 3], minor_dim: 4 }, values: [-3, 3, -3, 1, -3, 0, 2, 1, 3, 0, -4, -1] }, Matrix { data: VecStorage { data: [3, 1, 4, -5, 5, -2, -5, -1, 1, -1, 3, -3, -2, 4, 2, -1, -1, 3, -5, 5], nrows: Dynamic { value: 4 }, ncols: Dynamic { value: 5 } } }, NoTranspose, NoTranspose)
|
||||
cc cbd6dac45a2f610e10cf4c15d4614cdbf7dfedbfcd733e4cc65c2e79829d14b3 # shrinks to SpmmCsrArgs { c, beta, alpha, trans_a, a, trans_b, b } = SpmmCsrArgs { c: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0, 1, 1, 1, 1], minor_indices: [0], minor_dim: 1 }, values: [0] }, beta: 0, alpha: 1, trans_a: Transpose(true), a: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 0, 0, 1, 1, 1], minor_indices: [1], minor_dim: 5 }, values: [-1] }, trans_b: Transpose(true), b: CsrMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 2], minor_indices: [2, 4], minor_dim: 5 }, values: [-1, 0] } }
|
||||
cc 8af78e2e41087743c8696c4d5563d59464f284662ccf85efc81ac56747d528bb # shrinks to (a, b) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 6, 12, 18, 24, 30, 33], minor_indices: [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 1, 2, 5], minor_dim: 6 }, values: [0.4566433975117654, -0.5109683327713039, 0.0, -3.276901622678194, 0.0, -2.2065487385437095, 0.0, -0.42643054427847016, -2.9232369281581234, 0.0, 1.2913925579441763, 0.0, -1.4073766622090917, -4.795473113569459, 4.681765156869446, -0.821162215887913, 3.0315816068414794, -3.3986924718213407, -3.498903007282241, -3.1488953408335236, 3.458104636152161, -4.774694888508124, 2.603884664757498, 0.0, 0.0, -3.2650988857765535, 4.26699442646613, 0.0, -0.012223422086023561, 3.6899095325779285, -1.4264458042247958, 0.0, 3.4849193883471266] } }, Matrix { data: VecStorage { data: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -1.9513896933988457, -4.426942420881461, 0.0, 0.0, 0.0, -0.28264084049240257], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 2 } } })
|
||||
cc a4effd988fe352146fca365875e108ecf4f7d41f6ad54683e923ca6ce712e5d0 # shrinks to (a, b) = (CscMatrix { cs: CsMatrix { sparsity_pattern: SparsityPattern { major_offsets: [0, 5, 11, 17, 22, 27, 31], minor_indices: [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 3, 4, 5, 1, 2, 3, 4, 5, 0, 1, 3, 5], minor_dim: 6 }, values: [-2.24935510943371, -2.2288203680206227, 0.0, -1.029740125494273, 0.0, 0.0, 0.22632926934348507, -0.9123245943877407, 0.0, 3.8564332876991827, 0.0, 0.0, 0.0, -0.8235065737081717, 1.9337984046721566, 0.11003468246027737, -3.422112890579867, -3.7824068893569196, 0.0, -0.021700572247226546, -4.914783069982362, 0.6227245544506541, 0.0, 0.0, -4.411368879922364, -0.00013623178651567258, -2.613658177661417, -2.2783292441548637, 0.0, 1.351859435890189, -0.021345159183605134] } }, Matrix { data: VecStorage { data: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -4.519417607973404, 0.0, 0.0, 0.0, -0.21238483334481817], nrows: Dynamic { value: 6 }, ncols: Dynamic { value: 3 } } })
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,154 @@
|
|||
use nalgebra_sparse::pattern::{SparsityPattern, SparsityPatternFormatError};
|
||||
|
||||
#[test]
|
||||
fn sparsity_pattern_valid_data() {
|
||||
// Construct pattern from valid data and check that selected methods return results
|
||||
// that agree with expectations.
|
||||
|
||||
{
|
||||
// A pattern with zero explicitly stored entries
|
||||
let pattern =
|
||||
SparsityPattern::try_from_offsets_and_indices(3, 2, vec![0, 0, 0, 0], Vec::new())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(pattern.major_dim(), 3);
|
||||
assert_eq!(pattern.minor_dim(), 2);
|
||||
assert_eq!(pattern.nnz(), 0);
|
||||
assert_eq!(pattern.major_offsets(), &[0, 0, 0, 0]);
|
||||
assert_eq!(pattern.minor_indices(), &[]);
|
||||
assert_eq!(pattern.lane(0), &[]);
|
||||
assert_eq!(pattern.lane(1), &[]);
|
||||
assert_eq!(pattern.lane(2), &[]);
|
||||
assert!(pattern.entries().next().is_none());
|
||||
|
||||
assert_eq!(pattern, SparsityPattern::zeros(3, 2));
|
||||
|
||||
let (offsets, indices) = pattern.disassemble();
|
||||
assert_eq!(offsets, vec![0, 0, 0, 0]);
|
||||
assert_eq!(indices, vec![]);
|
||||
}
|
||||
|
||||
{
|
||||
// Arbitrary pattern
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let pattern =
|
||||
SparsityPattern::try_from_offsets_and_indices(3, 6, offsets.clone(), indices.clone())
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(pattern.major_dim(), 3);
|
||||
assert_eq!(pattern.minor_dim(), 6);
|
||||
assert_eq!(pattern.major_offsets(), offsets.as_slice());
|
||||
assert_eq!(pattern.minor_indices(), indices.as_slice());
|
||||
assert_eq!(pattern.nnz(), 5);
|
||||
assert_eq!(pattern.lane(0), &[0, 5]);
|
||||
assert_eq!(pattern.lane(1), &[]);
|
||||
assert_eq!(pattern.lane(2), &[1, 2, 3]);
|
||||
assert_eq!(
|
||||
pattern.entries().collect::<Vec<_>>(),
|
||||
vec![(0, 0), (0, 5), (2, 1), (2, 2), (2, 3)]
|
||||
);
|
||||
|
||||
let (offsets2, indices2) = pattern.disassemble();
|
||||
assert_eq!(offsets2, offsets);
|
||||
assert_eq!(indices2, indices);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sparsity_pattern_try_from_invalid_data() {
|
||||
{
|
||||
// Empty offset array (invalid length)
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(0, 0, Vec::new(), Vec::new());
|
||||
assert_eq!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Offset array invalid length for arbitrary data
|
||||
let offsets = vec![0, 3, 5];
|
||||
let indices = vec![0, 1, 2, 3, 5];
|
||||
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert!(matches!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
// Invalid first entry in offsets array
|
||||
let offsets = vec![1, 2, 2, 5];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert!(matches!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
// Invalid last entry in offsets array
|
||||
let offsets = vec![0, 2, 2, 4];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert!(matches!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::InvalidOffsetFirstLast)
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
// Invalid length of offsets array
|
||||
let offsets = vec![0, 2, 2];
|
||||
let indices = vec![0, 5, 1, 2, 3];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert!(matches!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::InvalidOffsetArrayLength)
|
||||
));
|
||||
}
|
||||
|
||||
{
|
||||
// Nonmonotonic offsets
|
||||
let offsets = vec![0, 3, 2, 5];
|
||||
let indices = vec![0, 1, 2, 3, 4];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert_eq!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::NonmonotonicOffsets)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Nonmonotonic minor indices
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 2, 3, 1, 4];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert_eq!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::NonmonotonicMinorIndices)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Minor index out of bounds
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 6, 1, 2, 3];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert_eq!(
|
||||
pattern,
|
||||
Err(SparsityPatternFormatError::MinorIndexOutOfBounds)
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
// Duplicate entry
|
||||
let offsets = vec![0, 2, 2, 5];
|
||||
let indices = vec![0, 5, 2, 2, 3];
|
||||
let pattern = SparsityPattern::try_from_offsets_and_indices(3, 6, offsets, indices);
|
||||
assert_eq!(pattern, Err(SparsityPatternFormatError::DuplicateEntry));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,247 @@
|
|||
#[test]
|
||||
#[ignore]
|
||||
fn coo_no_duplicates_generates_admissible_matrices() {
|
||||
//TODO
|
||||
}
|
||||
|
||||
#[cfg(feature = "slow-tests")]
|
||||
mod slow {
|
||||
use nalgebra::DMatrix;
|
||||
use nalgebra_sparse::proptest::{
|
||||
coo_no_duplicates, coo_with_duplicates, csc, csr, sparsity_pattern,
|
||||
};
|
||||
|
||||
use itertools::Itertools;
|
||||
use proptest::strategy::ValueTree;
|
||||
use proptest::test_runner::TestRunner;
|
||||
|
||||
use proptest::prelude::*;
|
||||
|
||||
use nalgebra_sparse::csr::CsrMatrix;
|
||||
use std::collections::HashSet;
|
||||
use std::iter::repeat;
|
||||
use std::ops::RangeInclusive;
|
||||
|
||||
fn generate_all_possible_matrices(
|
||||
value_range: RangeInclusive<i32>,
|
||||
rows_range: RangeInclusive<usize>,
|
||||
cols_range: RangeInclusive<usize>,
|
||||
) -> HashSet<DMatrix<i32>> {
|
||||
// Enumerate all possible combinations
|
||||
let mut all_combinations = HashSet::new();
|
||||
for nrows in rows_range {
|
||||
for ncols in cols_range.clone() {
|
||||
// For the given number of rows and columns
|
||||
let n_values = nrows * ncols;
|
||||
|
||||
if n_values == 0 {
|
||||
// If we have zero rows or columns, the set of matrices with the given
|
||||
// rows and columns is a single element: an empty matrix
|
||||
all_combinations.insert(DMatrix::from_row_slice(nrows, ncols, &[]));
|
||||
} else {
|
||||
// Otherwise, we need to sample all possible matrices.
|
||||
// To do this, we generate the values as the (multi) Cartesian product
|
||||
// of the value sets. For example, for a 2x2 matrices, we consider
|
||||
// all possible 4-element arrays that the matrices can take by
|
||||
// considering all elements in the cartesian product
|
||||
// V x V x V x V
|
||||
// where V is the set of eligible values, e.g. V := -1 ..= 1
|
||||
let values_iter = repeat(value_range.clone())
|
||||
.take(n_values)
|
||||
.multi_cartesian_product();
|
||||
for matrix_values in values_iter {
|
||||
all_combinations.insert(DMatrix::from_row_slice(
|
||||
nrows,
|
||||
ncols,
|
||||
&matrix_values,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
all_combinations
|
||||
}
|
||||
|
||||
#[cfg(feature = "slow-tests")]
|
||||
#[test]
|
||||
fn coo_no_duplicates_samples_all_admissible_outputs() {
|
||||
// Note: This test basically mirrors a similar test for `matrix` in the `nalgebra` repo.
|
||||
|
||||
// Test that the proptest generation covers all possible outputs for a small space of inputs
|
||||
// given enough samples.
|
||||
|
||||
// We use a deterministic test runner to make the test "stable".
|
||||
let mut runner = TestRunner::deterministic();
|
||||
|
||||
// This number needs to be high enough so that we with high probability sample
|
||||
// all possible cases
|
||||
let num_generated_matrices = 500000;
|
||||
|
||||
let values = -1..=1;
|
||||
let rows = 0..=2;
|
||||
let cols = 0..=3;
|
||||
let max_nnz = rows.end() * cols.end();
|
||||
let strategy = coo_no_duplicates(values.clone(), rows.clone(), cols.clone(), max_nnz);
|
||||
|
||||
// Enumerate all possible combinations
|
||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||
|
||||
let visited_combinations =
|
||||
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||
|
||||
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||
assert_eq!(
|
||||
visited_combinations, all_combinations,
|
||||
"Did not sample all possible values."
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "slow-tests")]
|
||||
#[test]
|
||||
fn coo_with_duplicates_samples_all_admissible_outputs() {
|
||||
// This is almost the same as the test for coo_no_duplicates, except that we need
|
||||
// a different "success" criterion, since coo_with_duplicates is able to generate
|
||||
// matrices with values outside of the value constraints. See below for details.
|
||||
|
||||
// We use a deterministic test runner to make the test "stable".
|
||||
let mut runner = TestRunner::deterministic();
|
||||
|
||||
// This number needs to be high enough so that we with high probability sample
|
||||
// all possible cases
|
||||
let num_generated_matrices = 500000;
|
||||
|
||||
let values = -1..=1;
|
||||
let rows = 0..=2;
|
||||
let cols = 0..=3;
|
||||
let max_nnz = rows.end() * cols.end();
|
||||
let strategy = coo_with_duplicates(values.clone(), rows.clone(), cols.clone(), max_nnz, 2);
|
||||
|
||||
// Enumerate all possible combinations that fit the constraints
|
||||
// (note: this is only a subset of the matrices that can be generated by
|
||||
// `coo_with_duplicates`)
|
||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||
|
||||
let visited_combinations =
|
||||
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||
|
||||
// Here we cannot verify that the set of visited combinations is *equal* to
|
||||
// all possible outcomes with the given constraints, however the
|
||||
// strategy should be able to generate all matrices that fit the constraints.
|
||||
// In other words, we need to determine that set of all admissible matrices
|
||||
// is contained in the set of visited matrices
|
||||
assert!(all_combinations.is_subset(&visited_combinations));
|
||||
}
|
||||
|
||||
#[cfg(feature = "slow-tests")]
|
||||
#[test]
|
||||
fn csr_samples_all_admissible_outputs() {
|
||||
// We use a deterministic test runner to make the test "stable".
|
||||
let mut runner = TestRunner::deterministic();
|
||||
|
||||
// This number needs to be high enough so that we with high probability sample
|
||||
// all possible cases
|
||||
let num_generated_matrices = 500000;
|
||||
|
||||
let values = -1..=1;
|
||||
let rows = 0..=2;
|
||||
let cols = 0..=3;
|
||||
let max_nnz = rows.end() * cols.end();
|
||||
let strategy = csr(values.clone(), rows.clone(), cols.clone(), max_nnz);
|
||||
|
||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||
|
||||
let visited_combinations =
|
||||
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||
|
||||
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||
assert_eq!(
|
||||
visited_combinations, all_combinations,
|
||||
"Did not sample all possible values."
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "slow-tests")]
|
||||
#[test]
|
||||
fn csc_samples_all_admissible_outputs() {
|
||||
// We use a deterministic test runner to make the test "stable".
|
||||
let mut runner = TestRunner::deterministic();
|
||||
|
||||
// This number needs to be high enough so that we with high probability sample
|
||||
// all possible cases
|
||||
let num_generated_matrices = 500000;
|
||||
|
||||
let values = -1..=1;
|
||||
let rows = 0..=2;
|
||||
let cols = 0..=3;
|
||||
let max_nnz = rows.end() * cols.end();
|
||||
let strategy = csc(values.clone(), rows.clone(), cols.clone(), max_nnz);
|
||||
|
||||
let all_combinations = generate_all_possible_matrices(values, rows, cols);
|
||||
|
||||
let visited_combinations =
|
||||
sample_matrix_output_space(strategy, &mut runner, num_generated_matrices);
|
||||
|
||||
assert_eq!(visited_combinations.len(), all_combinations.len());
|
||||
assert_eq!(
|
||||
visited_combinations, all_combinations,
|
||||
"Did not sample all possible values."
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(feature = "slow-tests")]
|
||||
#[test]
|
||||
fn sparsity_pattern_samples_all_admissible_outputs() {
|
||||
let mut runner = TestRunner::deterministic();
|
||||
|
||||
let num_generated_patterns = 50000;
|
||||
|
||||
let major_dims = 0..=2;
|
||||
let minor_dims = 0..=3;
|
||||
let max_nnz = major_dims.end() * minor_dims.end();
|
||||
let strategy = sparsity_pattern(major_dims.clone(), minor_dims.clone(), max_nnz);
|
||||
|
||||
let visited_patterns: HashSet<_> = sample_strategy(strategy, &mut runner)
|
||||
.take(num_generated_patterns)
|
||||
.map(|pattern| {
|
||||
// We represent patterns as dense matrices with 1 if an entry is occupied,
|
||||
// 0 otherwise
|
||||
let values = vec![1; pattern.nnz()];
|
||||
CsrMatrix::try_from_pattern_and_values(pattern, values).unwrap()
|
||||
})
|
||||
.map(|csr| DMatrix::from(&csr))
|
||||
.collect();
|
||||
|
||||
let all_possible_patterns = generate_all_possible_matrices(0..=1, major_dims, minor_dims);
|
||||
|
||||
assert_eq!(visited_patterns.len(), all_possible_patterns.len());
|
||||
assert_eq!(visited_patterns, all_possible_patterns);
|
||||
}
|
||||
|
||||
fn sample_matrix_output_space<S>(
|
||||
strategy: S,
|
||||
runner: &mut TestRunner,
|
||||
num_samples: usize,
|
||||
) -> HashSet<DMatrix<i32>>
|
||||
where
|
||||
S: Strategy,
|
||||
DMatrix<i32>: for<'b> From<&'b S::Value>,
|
||||
{
|
||||
sample_strategy(strategy, runner)
|
||||
.take(num_samples)
|
||||
.map(|matrix| DMatrix::from(&matrix))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn sample_strategy<'a, S: 'a + Strategy>(
|
||||
strategy: S,
|
||||
runner: &'a mut TestRunner,
|
||||
) -> impl 'a + Iterator<Item = S::Value> {
|
||||
repeat(()).map(move |_| {
|
||||
let tree = strategy
|
||||
.new_tree(runner)
|
||||
.expect("Tree generation should not fail");
|
||||
let value = tree.current();
|
||||
value
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
//! Abstract definition of a matrix data storage allocator.
|
||||
|
||||
use std::any::Any;
|
||||
use std::mem;
|
||||
|
||||
use crate::base::constraint::{SameNumberOfColumns, SameNumberOfRows, ShapeConstraint};
|
||||
use crate::base::dimension::{Dim, U1};
|
||||
|
@ -21,7 +22,7 @@ pub trait Allocator<N: Scalar, R: Dim, C: Dim = U1>: Any + Sized {
|
|||
type Buffer: ContiguousStorageMut<N, R, C> + Clone;
|
||||
|
||||
/// Allocates a buffer with the given number of rows and columns without initializing its content.
|
||||
unsafe fn allocate_uninitialized(nrows: R, ncols: C) -> Self::Buffer;
|
||||
unsafe fn allocate_uninitialized(nrows: R, ncols: C) -> mem::MaybeUninit<Self::Buffer>;
|
||||
|
||||
/// Allocates a buffer initialized with the content of the given iterator.
|
||||
fn allocate_from_iterator<I: IntoIterator<Item = N>>(
|
||||
|
|
|
@ -394,6 +394,26 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "bytemuck")]
|
||||
unsafe impl<N: Scalar + bytemuck::Zeroable, R: DimName, C: DimName> bytemuck::Zeroable
|
||||
for ArrayStorage<N, R, C>
|
||||
where
|
||||
R::Value: Mul<C::Value>,
|
||||
Prod<R::Value, C::Value>: ArrayLength<N>,
|
||||
Self: Copy,
|
||||
{
|
||||
}
|
||||
|
||||
#[cfg(feature = "bytemuck")]
|
||||
unsafe impl<N: Scalar + bytemuck::Pod, R: DimName, C: DimName> bytemuck::Pod
|
||||
for ArrayStorage<N, R, C>
|
||||
where
|
||||
R::Value: Mul<C::Value>,
|
||||
Prod<R::Value, C::Value>: ArrayLength<N>,
|
||||
Self: Copy,
|
||||
{
|
||||
}
|
||||
|
||||
#[cfg(feature = "abomonation-serialize")]
|
||||
impl<N, R, C> Abomonation for ArrayStorage<N, R, C>
|
||||
where
|
||||
|
|
|
@ -1328,7 +1328,8 @@ where
|
|||
ShapeConstraint: DimEq<D1, D1> + DimEq<D1, R3> + DimEq<C3, D4>,
|
||||
DefaultAllocator: Allocator<N, D1>,
|
||||
{
|
||||
let mut work = unsafe { Vector::new_uninitialized_generic(self.data.shape().0, U1) };
|
||||
let mut work =
|
||||
unsafe { crate::unimplemented_or_uninitialized_generic!(self.data.shape().0, U1) };
|
||||
self.quadform_tr_with_workspace(&mut work, alpha, lhs, mid, beta)
|
||||
}
|
||||
|
||||
|
@ -1421,7 +1422,8 @@ where
|
|||
ShapeConstraint: DimEq<D2, R3> + DimEq<D1, C3> + AreMultipliable<C3, R3, D2, U1>,
|
||||
DefaultAllocator: Allocator<N, D2>,
|
||||
{
|
||||
let mut work = unsafe { Vector::new_uninitialized_generic(mid.data.shape().0, U1) };
|
||||
let mut work =
|
||||
unsafe { crate::unimplemented_or_uninitialized_generic!(mid.data.shape().0, U1) };
|
||||
self.quadform_with_workspace(&mut work, alpha, mid, rhs, beta)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ use rand::Rng;
|
|||
#[cfg(feature = "std")]
|
||||
use rand_distr::StandardNormal;
|
||||
use std::iter;
|
||||
use std::mem;
|
||||
use typenum::{self, Cmp, Greater};
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
|
@ -25,6 +26,23 @@ use crate::base::dimension::{Dim, DimName, Dynamic, U1, U2, U3, U4, U5, U6};
|
|||
use crate::base::storage::Storage;
|
||||
use crate::base::{DefaultAllocator, Matrix, MatrixMN, MatrixN, Scalar, Unit, Vector, VectorN};
|
||||
|
||||
/// When "no_unsound_assume_init" is enabled, expands to `unimplemented!()` instead of `new_uninitialized_generic().assume_init()`.
|
||||
/// Intended as a placeholder, each callsite should be refactored to use uninitialized memory soundly
|
||||
#[macro_export]
|
||||
macro_rules! unimplemented_or_uninitialized_generic {
|
||||
($nrows:expr, $ncols:expr) => {{
|
||||
#[cfg(feature="no_unsound_assume_init")] {
|
||||
// Some of the call sites need the number of rows and columns from this to infer a type, so
|
||||
// uninitialized memory is used to infer the type, as `N: Zero` isn't available at all callsites.
|
||||
// This may technically still be UB even though the assume_init is dead code, but all callsites should be fixed before #556 is closed.
|
||||
let typeinference_helper = crate::base::Matrix::new_uninitialized_generic($nrows, $ncols);
|
||||
unimplemented!();
|
||||
typeinference_helper.assume_init()
|
||||
}
|
||||
#[cfg(not(feature="no_unsound_assume_init"))] { crate::base::Matrix::new_uninitialized_generic($nrows, $ncols).assume_init() }
|
||||
}}
|
||||
}
|
||||
|
||||
/// # Generic constructors
|
||||
/// This set of matrix and vector construction functions are all generic
|
||||
/// with-regard to the matrix dimensions. They all expect to be given
|
||||
|
@ -38,8 +56,8 @@ where
|
|||
/// Creates a new uninitialized matrix. If the matrix has a compile-time dimension, this panics
|
||||
/// if `nrows != R::to_usize()` or `ncols != C::to_usize()`.
|
||||
#[inline]
|
||||
pub unsafe fn new_uninitialized_generic(nrows: R, ncols: C) -> Self {
|
||||
Self::from_data(DefaultAllocator::allocate_uninitialized(nrows, ncols))
|
||||
pub unsafe fn new_uninitialized_generic(nrows: R, ncols: C) -> mem::MaybeUninit<Self> {
|
||||
Self::from_uninitialized_data(DefaultAllocator::allocate_uninitialized(nrows, ncols))
|
||||
}
|
||||
|
||||
/// Creates a matrix with all its elements set to `elem`.
|
||||
|
@ -88,7 +106,7 @@ where
|
|||
"Matrix init. error: the slice did not contain the right number of elements."
|
||||
);
|
||||
|
||||
let mut res = unsafe { Self::new_uninitialized_generic(nrows, ncols) };
|
||||
let mut res = unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
|
||||
let mut iter = slice.iter();
|
||||
|
||||
for i in 0..nrows.value() {
|
||||
|
@ -114,7 +132,7 @@ where
|
|||
where
|
||||
F: FnMut(usize, usize) -> N,
|
||||
{
|
||||
let mut res = unsafe { Self::new_uninitialized_generic(nrows, ncols) };
|
||||
let mut res: Self = unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
|
||||
|
||||
for j in 0..ncols.value() {
|
||||
for i in 0..nrows.value() {
|
||||
|
@ -356,7 +374,7 @@ macro_rules! impl_constructors(
|
|||
($($Dims: ty),*; $(=> $DimIdent: ident: $DimBound: ident),*; $($gargs: expr),*; $($args: ident),*) => {
|
||||
/// Creates a new uninitialized matrix or vector.
|
||||
#[inline]
|
||||
pub unsafe fn new_uninitialized($($args: usize),*) -> Self {
|
||||
pub unsafe fn new_uninitialized($($args: usize),*) -> mem::MaybeUninit<Self> {
|
||||
Self::new_uninitialized_generic($($gargs),*)
|
||||
}
|
||||
|
||||
|
@ -806,8 +824,8 @@ where
|
|||
{
|
||||
#[inline]
|
||||
fn sample<'a, G: Rng + ?Sized>(&self, rng: &'a mut G) -> MatrixMN<N, R, C> {
|
||||
let nrows = R::try_to_usize().unwrap_or_else(|| rng.gen_range(0, 10));
|
||||
let ncols = C::try_to_usize().unwrap_or_else(|| rng.gen_range(0, 10));
|
||||
let nrows = R::try_to_usize().unwrap_or_else(|| rng.gen_range(0..10));
|
||||
let ncols = C::try_to_usize().unwrap_or_else(|| rng.gen_range(0..10));
|
||||
|
||||
MatrixMN::from_fn_generic(R::from_usize(nrows), C::from_usize(ncols), |_, _| rng.gen())
|
||||
}
|
||||
|
@ -823,9 +841,9 @@ where
|
|||
Owned<N, R, C>: Clone + Send,
|
||||
{
|
||||
#[inline]
|
||||
fn arbitrary<G: Gen>(g: &mut G) -> Self {
|
||||
let nrows = R::try_to_usize().unwrap_or(g.gen_range(0, 10));
|
||||
let ncols = C::try_to_usize().unwrap_or(g.gen_range(0, 10));
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let nrows = R::try_to_usize().unwrap_or(usize::arbitrary(g) % 10);
|
||||
let ncols = C::try_to_usize().unwrap_or(usize::arbitrary(g) % 10);
|
||||
|
||||
Self::from_fn_generic(R::from_usize(nrows), C::from_usize(ncols), |_, _| {
|
||||
N::arbitrary(g)
|
||||
|
@ -865,7 +883,10 @@ macro_rules! componentwise_constructors_impl(
|
|||
#[inline]
|
||||
pub fn new($($args: N),*) -> Self {
|
||||
unsafe {
|
||||
let mut res = Self::new_uninitialized();
|
||||
#[cfg(feature="no_unsound_assume_init")]
|
||||
let mut res: Self = unimplemented!();
|
||||
#[cfg(not(feature="no_unsound_assume_init"))]
|
||||
let mut res = Self::new_uninitialized().assume_init();
|
||||
$( *res.get_unchecked_mut(($irow, $icol)) = $args; )*
|
||||
|
||||
res
|
||||
|
|
|
@ -50,7 +50,8 @@ where
|
|||
let nrows2 = R2::from_usize(nrows);
|
||||
let ncols2 = C2::from_usize(ncols);
|
||||
|
||||
let mut res = unsafe { MatrixMN::<N2, R2, C2>::new_uninitialized_generic(nrows2, ncols2) };
|
||||
let mut res: MatrixMN<N2, R2, C2> =
|
||||
unsafe { crate::unimplemented_or_uninitialized_generic!(nrows2, ncols2) };
|
||||
for i in 0..nrows {
|
||||
for j in 0..ncols {
|
||||
unsafe {
|
||||
|
@ -73,7 +74,7 @@ where
|
|||
let nrows = R1::from_usize(nrows2);
|
||||
let ncols = C1::from_usize(ncols2);
|
||||
|
||||
let mut res = unsafe { Self::new_uninitialized_generic(nrows, ncols) };
|
||||
let mut res: Self = unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
|
||||
for i in 0..nrows2 {
|
||||
for j in 0..ncols2 {
|
||||
unsafe {
|
||||
|
@ -117,9 +118,9 @@ macro_rules! impl_from_into_asref_1D(
|
|||
fn from(arr: [N; $SZ]) -> Self {
|
||||
unsafe {
|
||||
let mut res = Self::new_uninitialized();
|
||||
ptr::copy_nonoverlapping(&arr[0], res.data.ptr_mut(), $SZ);
|
||||
ptr::copy_nonoverlapping(&arr[0], (*res.as_mut_ptr()).data.ptr_mut(), $SZ);
|
||||
|
||||
res
|
||||
res.assume_init()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -184,9 +185,9 @@ macro_rules! impl_from_into_asref_2D(
|
|||
fn from(arr: [[N; $SZRows]; $SZCols]) -> Self {
|
||||
unsafe {
|
||||
let mut res = Self::new_uninitialized();
|
||||
ptr::copy_nonoverlapping(&arr[0][0], res.data.ptr_mut(), $SZRows * $SZCols);
|
||||
ptr::copy_nonoverlapping(&arr[0][0], (*res.as_mut_ptr()).data.ptr_mut(), $SZRows * $SZCols);
|
||||
|
||||
res
|
||||
res.assume_init()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -244,9 +245,9 @@ macro_rules! impl_from_into_mint_1D(
|
|||
fn from(v: mint::$VT<N>) -> Self {
|
||||
unsafe {
|
||||
let mut res = Self::new_uninitialized();
|
||||
ptr::copy_nonoverlapping(&v.x, res.data.ptr_mut(), $SZ);
|
||||
ptr::copy_nonoverlapping(&v.x, (*res.as_mut_ptr()).data.ptr_mut(), $SZ);
|
||||
|
||||
res
|
||||
res.assume_init()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -306,13 +307,13 @@ macro_rules! impl_from_into_mint_2D(
|
|||
fn from(m: mint::$MV<N>) -> Self {
|
||||
unsafe {
|
||||
let mut res = Self::new_uninitialized();
|
||||
let mut ptr = res.data.ptr_mut();
|
||||
let mut ptr = (*res.as_mut_ptr()).data.ptr_mut();
|
||||
$(
|
||||
ptr::copy_nonoverlapping(&m.$component.x, ptr, $SZRows);
|
||||
ptr = ptr.offset($SZRows);
|
||||
)*
|
||||
let _ = ptr;
|
||||
res
|
||||
res.assume_init()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,9 +45,8 @@ where
|
|||
type Buffer = ArrayStorage<N, R, C>;
|
||||
|
||||
#[inline]
|
||||
unsafe fn allocate_uninitialized(_: R, _: C) -> Self::Buffer {
|
||||
// TODO: Undefined behavior, see #556
|
||||
mem::MaybeUninit::<Self::Buffer>::uninit().assume_init()
|
||||
unsafe fn allocate_uninitialized(_: R, _: C) -> mem::MaybeUninit<Self::Buffer> {
|
||||
mem::MaybeUninit::<Self::Buffer>::uninit()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -56,7 +55,10 @@ where
|
|||
ncols: C,
|
||||
iter: I,
|
||||
) -> Self::Buffer {
|
||||
let mut res = unsafe { Self::allocate_uninitialized(nrows, ncols) };
|
||||
#[cfg(feature = "no_unsound_assume_init")]
|
||||
let mut res: Self::Buffer = unimplemented!();
|
||||
#[cfg(not(feature = "no_unsound_assume_init"))]
|
||||
let mut res = unsafe { Self::allocate_uninitialized(nrows, ncols).assume_init() };
|
||||
let mut count = 0;
|
||||
|
||||
for (res, e) in res.iter_mut().zip(iter.into_iter()) {
|
||||
|
@ -80,13 +82,13 @@ impl<N: Scalar, C: Dim> Allocator<N, Dynamic, C> for DefaultAllocator {
|
|||
type Buffer = VecStorage<N, Dynamic, C>;
|
||||
|
||||
#[inline]
|
||||
unsafe fn allocate_uninitialized(nrows: Dynamic, ncols: C) -> Self::Buffer {
|
||||
unsafe fn allocate_uninitialized(nrows: Dynamic, ncols: C) -> mem::MaybeUninit<Self::Buffer> {
|
||||
let mut res = Vec::new();
|
||||
let length = nrows.value() * ncols.value();
|
||||
res.reserve_exact(length);
|
||||
res.set_len(length);
|
||||
|
||||
VecStorage::new(nrows, ncols, res)
|
||||
mem::MaybeUninit::new(VecStorage::new(nrows, ncols, res))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -110,13 +112,13 @@ impl<N: Scalar, R: DimName> Allocator<N, R, Dynamic> for DefaultAllocator {
|
|||
type Buffer = VecStorage<N, R, Dynamic>;
|
||||
|
||||
#[inline]
|
||||
unsafe fn allocate_uninitialized(nrows: R, ncols: Dynamic) -> Self::Buffer {
|
||||
unsafe fn allocate_uninitialized(nrows: R, ncols: Dynamic) -> mem::MaybeUninit<Self::Buffer> {
|
||||
let mut res = Vec::new();
|
||||
let length = nrows.value() * ncols.value();
|
||||
res.reserve_exact(length);
|
||||
res.set_len(length);
|
||||
|
||||
VecStorage::new(nrows, ncols, res)
|
||||
mem::MaybeUninit::new(VecStorage::new(nrows, ncols, res))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
|
@ -156,7 +158,11 @@ where
|
|||
cto: CTo,
|
||||
buf: <Self as Allocator<N, RFrom, CFrom>>::Buffer,
|
||||
) -> ArrayStorage<N, RTo, CTo> {
|
||||
let mut res = <Self as Allocator<N, RTo, CTo>>::allocate_uninitialized(rto, cto);
|
||||
#[cfg(feature = "no_unsound_assume_init")]
|
||||
let mut res: ArrayStorage<N, RTo, CTo> = unimplemented!();
|
||||
#[cfg(not(feature = "no_unsound_assume_init"))]
|
||||
let mut res =
|
||||
<Self as Allocator<N, RTo, CTo>>::allocate_uninitialized(rto, cto).assume_init();
|
||||
|
||||
let (rfrom, cfrom) = buf.shape();
|
||||
|
||||
|
@ -184,7 +190,11 @@ where
|
|||
cto: CTo,
|
||||
buf: ArrayStorage<N, RFrom, CFrom>,
|
||||
) -> VecStorage<N, Dynamic, CTo> {
|
||||
let mut res = <Self as Allocator<N, Dynamic, CTo>>::allocate_uninitialized(rto, cto);
|
||||
#[cfg(feature = "no_unsound_assume_init")]
|
||||
let mut res: VecStorage<N, Dynamic, CTo> = unimplemented!();
|
||||
#[cfg(not(feature = "no_unsound_assume_init"))]
|
||||
let mut res =
|
||||
<Self as Allocator<N, Dynamic, CTo>>::allocate_uninitialized(rto, cto).assume_init();
|
||||
|
||||
let (rfrom, cfrom) = buf.shape();
|
||||
|
||||
|
@ -212,7 +222,11 @@ where
|
|||
cto: Dynamic,
|
||||
buf: ArrayStorage<N, RFrom, CFrom>,
|
||||
) -> VecStorage<N, RTo, Dynamic> {
|
||||
let mut res = <Self as Allocator<N, RTo, Dynamic>>::allocate_uninitialized(rto, cto);
|
||||
#[cfg(feature = "no_unsound_assume_init")]
|
||||
let mut res: VecStorage<N, RTo, Dynamic> = unimplemented!();
|
||||
#[cfg(not(feature = "no_unsound_assume_init"))]
|
||||
let mut res =
|
||||
<Self as Allocator<N, RTo, Dynamic>>::allocate_uninitialized(rto, cto).assume_init();
|
||||
|
||||
let (rfrom, cfrom) = buf.shape();
|
||||
|
||||
|
|
|
@ -54,8 +54,9 @@ impl<N: Scalar + Zero, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
|||
{
|
||||
let irows = irows.into_iter();
|
||||
let ncols = self.data.shape().1;
|
||||
let mut res =
|
||||
unsafe { MatrixMN::new_uninitialized_generic(Dynamic::new(irows.len()), ncols) };
|
||||
let mut res = unsafe {
|
||||
crate::unimplemented_or_uninitialized_generic!(Dynamic::new(irows.len()), ncols)
|
||||
};
|
||||
|
||||
// First, check that all the indices from irows are valid.
|
||||
// This will allow us to use unchecked access in the inner loop.
|
||||
|
@ -89,8 +90,9 @@ impl<N: Scalar + Zero, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
|||
{
|
||||
let icols = icols.into_iter();
|
||||
let nrows = self.data.shape().0;
|
||||
let mut res =
|
||||
unsafe { MatrixMN::new_uninitialized_generic(nrows, Dynamic::new(icols.len())) };
|
||||
let mut res = unsafe {
|
||||
crate::unimplemented_or_uninitialized_generic!(nrows, Dynamic::new(icols.len()))
|
||||
};
|
||||
|
||||
for (destination, source) in icols.enumerate() {
|
||||
res.column_mut(destination).copy_from(&self.column(*source))
|
||||
|
@ -896,7 +898,9 @@ impl<N: Scalar> DMatrix<N> {
|
|||
where
|
||||
DefaultAllocator: Reallocator<N, Dynamic, Dynamic, Dynamic, Dynamic>,
|
||||
{
|
||||
let placeholder = unsafe { Self::new_uninitialized(0, 0) };
|
||||
let placeholder = unsafe {
|
||||
crate::unimplemented_or_uninitialized_generic!(Dynamic::new(0), Dynamic::new(0))
|
||||
};
|
||||
let old = mem::replace(self, placeholder);
|
||||
let new = old.resize(new_nrows, new_ncols, val);
|
||||
let _ = mem::replace(self, new);
|
||||
|
@ -919,8 +923,9 @@ where
|
|||
where
|
||||
DefaultAllocator: Reallocator<N, Dynamic, C, Dynamic, C>,
|
||||
{
|
||||
let placeholder =
|
||||
unsafe { Self::new_uninitialized_generic(Dynamic::new(0), self.data.shape().1) };
|
||||
let placeholder = unsafe {
|
||||
crate::unimplemented_or_uninitialized_generic!(Dynamic::new(0), self.data.shape().1)
|
||||
};
|
||||
let old = mem::replace(self, placeholder);
|
||||
let new = old.resize_vertically(new_nrows, val);
|
||||
let _ = mem::replace(self, new);
|
||||
|
@ -943,8 +948,9 @@ where
|
|||
where
|
||||
DefaultAllocator: Reallocator<N, R, Dynamic, R, Dynamic>,
|
||||
{
|
||||
let placeholder =
|
||||
unsafe { Self::new_uninitialized_generic(self.data.shape().0, Dynamic::new(0)) };
|
||||
let placeholder = unsafe {
|
||||
crate::unimplemented_or_uninitialized_generic!(self.data.shape().0, Dynamic::new(0))
|
||||
};
|
||||
let old = mem::replace(self, placeholder);
|
||||
let new = old.resize_horizontally(new_ncols, val);
|
||||
let _ = mem::replace(self, new);
|
||||
|
|
|
@ -7,7 +7,7 @@ use rand::Rng;
|
|||
#[cfg(feature = "arbitrary")]
|
||||
#[doc(hidden)]
|
||||
#[inline]
|
||||
pub fn reject<G: Gen, F: FnMut(&T) -> bool, T: Arbitrary>(g: &mut G, f: F) -> T {
|
||||
pub fn reject<F: FnMut(&T) -> bool, T: Arbitrary>(g: &mut Gen, f: F) -> T {
|
||||
use std::iter;
|
||||
iter::repeat(())
|
||||
.map(|_| Arbitrary::arbitrary(g))
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
//! Matrix iterators.
|
||||
|
||||
use std::iter::FusedIterator;
|
||||
use std::marker::PhantomData;
|
||||
use std::mem;
|
||||
|
||||
|
@ -111,6 +112,46 @@ macro_rules! iterator {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a, N: Scalar, R: Dim, C: Dim, S: 'a + $Storage<N, R, C>> DoubleEndedIterator
|
||||
for $Name<'a, N, R, C, S>
|
||||
{
|
||||
#[inline]
|
||||
fn next_back(&mut self) -> Option<$Ref> {
|
||||
unsafe {
|
||||
if self.size == 0 {
|
||||
None
|
||||
} else {
|
||||
// Pre-decrement `size` such that it now counts to the
|
||||
// element we want to return.
|
||||
self.size -= 1;
|
||||
|
||||
// Fetch strides
|
||||
let inner_stride = self.strides.0.value();
|
||||
let outer_stride = self.strides.1.value();
|
||||
|
||||
// Compute number of rows
|
||||
// Division should be exact
|
||||
let inner_raw_size = self.inner_end.offset_from(self.inner_ptr) as usize;
|
||||
let inner_size = inner_raw_size / inner_stride;
|
||||
|
||||
// Compute rows and cols remaining
|
||||
let outer_remaining = self.size / inner_size;
|
||||
let inner_remaining = self.size % inner_size;
|
||||
|
||||
// Compute pointer to last element
|
||||
let last = self.ptr.offset(
|
||||
(outer_remaining * outer_stride + inner_remaining * inner_stride)
|
||||
as isize,
|
||||
);
|
||||
|
||||
// We want either `& *last` or `&mut *last` here, depending
|
||||
// on the mutability of `$Ref`.
|
||||
Some(mem::transmute(last))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, N: Scalar, R: Dim, C: Dim, S: 'a + $Storage<N, R, C>> ExactSizeIterator
|
||||
for $Name<'a, N, R, C, S>
|
||||
{
|
||||
|
@ -119,6 +160,11 @@ macro_rules! iterator {
|
|||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, N: Scalar, R: Dim, C: Dim, S: 'a + $Storage<N, R, C>> FusedIterator
|
||||
for $Name<'a, N, R, C, S>
|
||||
{
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -279,6 +279,22 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> matrixcompare_core::DenseAc
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "bytemuck")]
|
||||
unsafe impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> bytemuck::Zeroable
|
||||
for Matrix<N, R, C, S>
|
||||
where
|
||||
S: bytemuck::Zeroable,
|
||||
{
|
||||
}
|
||||
|
||||
#[cfg(feature = "bytemuck")]
|
||||
unsafe impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> bytemuck::Pod for Matrix<N, R, C, S>
|
||||
where
|
||||
S: bytemuck::Pod,
|
||||
Self: Copy,
|
||||
{
|
||||
}
|
||||
|
||||
impl<N: Scalar, R: Dim, C: Dim, S> Matrix<N, R, C, S> {
|
||||
/// Creates a new matrix with the given data without statically checking that the matrix
|
||||
/// dimension matches the storage dimension.
|
||||
|
@ -298,6 +314,21 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
|||
unsafe { Self::from_data_statically_unchecked(data) }
|
||||
}
|
||||
|
||||
/// Creates a new uninitialized matrix with the given uninitialized data
|
||||
pub unsafe fn from_uninitialized_data(data: mem::MaybeUninit<S>) -> mem::MaybeUninit<Self> {
|
||||
let res: Matrix<N, R, C, mem::MaybeUninit<S>> = Matrix {
|
||||
data,
|
||||
_phantoms: PhantomData,
|
||||
};
|
||||
let res: mem::MaybeUninit<Matrix<N, R, C, mem::MaybeUninit<S>>> =
|
||||
mem::MaybeUninit::new(res);
|
||||
// safety: since we wrap the inner MaybeUninit in an outer MaybeUninit above, the fact that the `data` field is partially-uninitialized is still opaque.
|
||||
// with s/transmute_copy/transmute/, rustc claims that `MaybeUninit<Matrix<N, R, C, MaybeUninit<S>>>` may be of a different size from `MaybeUninit<Matrix<N, R, C, S>>`
|
||||
// but MaybeUninit's documentation says "MaybeUninit<T> is guaranteed to have the same size, alignment, and ABI as T", which implies those types should be the same size
|
||||
let res: mem::MaybeUninit<Matrix<N, R, C, S>> = mem::transmute_copy(&res);
|
||||
res
|
||||
}
|
||||
|
||||
/// The shape of this matrix returned as the tuple (number of rows, number of columns).
|
||||
///
|
||||
/// # Examples:
|
||||
|
@ -497,7 +528,7 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
|||
let ncols: SameShapeC<C, C2> = Dim::from_usize(ncols);
|
||||
|
||||
let mut res: MatrixSum<N, R, C, R2, C2> =
|
||||
unsafe { Matrix::new_uninitialized_generic(nrows, ncols) };
|
||||
unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
|
||||
|
||||
// TODO: use copy_from
|
||||
for j in 0..res.ncols() {
|
||||
|
@ -546,7 +577,7 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
|||
let (nrows, ncols) = self.data.shape();
|
||||
|
||||
unsafe {
|
||||
let mut res = Matrix::new_uninitialized_generic(ncols, nrows);
|
||||
let mut res = crate::unimplemented_or_uninitialized_generic!(ncols, nrows);
|
||||
self.transpose_to(&mut res);
|
||||
|
||||
res
|
||||
|
@ -564,7 +595,8 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
|||
{
|
||||
let (nrows, ncols) = self.data.shape();
|
||||
|
||||
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) };
|
||||
let mut res: MatrixMN<N2, R, C> =
|
||||
unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
|
||||
|
||||
for j in 0..ncols.value() {
|
||||
for i in 0..nrows.value() {
|
||||
|
@ -608,7 +640,8 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
|||
{
|
||||
let (nrows, ncols) = self.data.shape();
|
||||
|
||||
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) };
|
||||
let mut res: MatrixMN<N2, R, C> =
|
||||
unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
|
||||
|
||||
for j in 0..ncols.value() {
|
||||
for i in 0..nrows.value() {
|
||||
|
@ -635,7 +668,8 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
|||
{
|
||||
let (nrows, ncols) = self.data.shape();
|
||||
|
||||
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) };
|
||||
let mut res: MatrixMN<N3, R, C> =
|
||||
unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
|
||||
|
||||
assert_eq!(
|
||||
(nrows.value(), ncols.value()),
|
||||
|
@ -676,7 +710,8 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
|||
{
|
||||
let (nrows, ncols) = self.data.shape();
|
||||
|
||||
let mut res = unsafe { MatrixMN::new_uninitialized_generic(nrows, ncols) };
|
||||
let mut res: MatrixMN<N4, R, C> =
|
||||
unsafe { crate::unimplemented_or_uninitialized_generic!(nrows, ncols) };
|
||||
|
||||
assert_eq!(
|
||||
(nrows.value(), ncols.value()),
|
||||
|
@ -1170,7 +1205,8 @@ impl<N: SimdComplexField, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S
|
|||
let (nrows, ncols) = self.data.shape();
|
||||
|
||||
unsafe {
|
||||
let mut res: MatrixMN<_, C, R> = Matrix::new_uninitialized_generic(ncols, nrows);
|
||||
let mut res: MatrixMN<_, C, R> =
|
||||
crate::unimplemented_or_uninitialized_generic!(ncols, nrows);
|
||||
self.adjoint_to(&mut res);
|
||||
|
||||
res
|
||||
|
@ -1311,7 +1347,8 @@ impl<N: Scalar, D: Dim, S: Storage<N, D, D>> SquareMatrix<N, D, S> {
|
|||
);
|
||||
|
||||
let dim = self.data.shape().0;
|
||||
let mut res = unsafe { VectorN::new_uninitialized_generic(dim, U1) };
|
||||
let mut res: VectorN<N2, D> =
|
||||
unsafe { crate::unimplemented_or_uninitialized_generic!(dim, U1) };
|
||||
|
||||
for i in 0..dim.value() {
|
||||
unsafe {
|
||||
|
@ -1438,7 +1475,8 @@ impl<N: Scalar + Zero, D: DimAdd<U1>, S: Storage<N, D>> Vector<N, D, S> {
|
|||
{
|
||||
let len = self.len();
|
||||
let hnrows = DimSum::<D, U1>::from_usize(len + 1);
|
||||
let mut res = unsafe { VectorN::<N, _>::new_uninitialized_generic(hnrows, U1) };
|
||||
let mut res: VectorN<N, _> =
|
||||
unsafe { crate::unimplemented_or_uninitialized_generic!(hnrows, U1) };
|
||||
res.generic_slice_mut((0, 0), self.data.shape())
|
||||
.copy_from(self);
|
||||
res[(len, 0)] = element;
|
||||
|
@ -1783,7 +1821,8 @@ impl<N: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: Storage<N
|
|||
// TODO: soooo ugly!
|
||||
let nrows = SameShapeR::<R, R2>::from_usize(3);
|
||||
let ncols = SameShapeC::<C, C2>::from_usize(1);
|
||||
let mut res = Matrix::new_uninitialized_generic(nrows, ncols);
|
||||
let mut res: MatrixCross<N, R, C, R2, C2> =
|
||||
crate::unimplemented_or_uninitialized_generic!(nrows, ncols);
|
||||
|
||||
let ax = self.get_unchecked((0, 0));
|
||||
let ay = self.get_unchecked((1, 0));
|
||||
|
@ -1807,7 +1846,8 @@ impl<N: Scalar + ClosedAdd + ClosedSub + ClosedMul, R: Dim, C: Dim, S: Storage<N
|
|||
// TODO: ugly!
|
||||
let nrows = SameShapeR::<R, R2>::from_usize(1);
|
||||
let ncols = SameShapeC::<C, C2>::from_usize(3);
|
||||
let mut res = Matrix::new_uninitialized_generic(nrows, ncols);
|
||||
let mut res: MatrixCross<N, R, C, R2, C2> =
|
||||
crate::unimplemented_or_uninitialized_generic!(nrows, ncols);
|
||||
|
||||
let ax = self.get_unchecked((0, 0));
|
||||
let ay = self.get_unchecked((0, 1));
|
||||
|
|
|
@ -433,8 +433,8 @@ where
|
|||
"Matrix meet/join error: mismatched dimensions."
|
||||
);
|
||||
|
||||
let mut mres = unsafe { Self::new_uninitialized_generic(shape.0, shape.1) };
|
||||
let mut jres = unsafe { Self::new_uninitialized_generic(shape.0, shape.1) };
|
||||
let mut mres = unsafe { crate::unimplemented_or_uninitialized_generic!(shape.0, shape.1) };
|
||||
let mut jres = unsafe { crate::unimplemented_or_uninitialized_generic!(shape.0, shape.1) };
|
||||
|
||||
for i in 0..shape.0.value() * shape.1.value() {
|
||||
unsafe {
|
||||
|
|
|
@ -15,6 +15,7 @@ mod alias_slice;
|
|||
mod array_storage;
|
||||
mod cg;
|
||||
mod componentwise;
|
||||
#[macro_use]
|
||||
mod construction;
|
||||
mod construction_slice;
|
||||
mod conversion;
|
||||
|
|
|
@ -8,7 +8,7 @@ use crate::allocator::Allocator;
|
|||
use crate::base::{DefaultAllocator, Dim, DimName, Matrix, MatrixMN, Normed, VectorN};
|
||||
use crate::constraint::{SameNumberOfColumns, SameNumberOfRows, ShapeConstraint};
|
||||
use crate::storage::{Storage, StorageMut};
|
||||
use crate::{ComplexField, Scalar, SimdComplexField, Unit};
|
||||
use crate::{ComplexField, RealField, Scalar, SimdComplexField, Unit};
|
||||
use simba::scalar::ClosedNeg;
|
||||
use simba::simd::{SimdOption, SimdPartialOrd};
|
||||
|
||||
|
@ -334,11 +334,27 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
|||
{
|
||||
let n = self.norm();
|
||||
|
||||
if n >= min_magnitude {
|
||||
if n > min_magnitude {
|
||||
self.scale_mut(magnitude / n)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a new vector with the same magnitude as `self` clamped between `0.0` and `max`.
|
||||
#[inline]
|
||||
pub fn cap_magnitude(&self, max: N::RealField) -> MatrixMN<N, R, C>
|
||||
where
|
||||
N: RealField,
|
||||
DefaultAllocator: Allocator<N, R, C>,
|
||||
{
|
||||
let n = self.norm();
|
||||
|
||||
if n > max {
|
||||
self.scale(max / n)
|
||||
} else {
|
||||
self.clone_owned()
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a normalized version of this matrix unless its norm as smaller or equal to `eps`.
|
||||
///
|
||||
/// The components of this matrix cannot be SIMD types (see `simd_try_normalize`) instead.
|
||||
|
|
|
@ -331,7 +331,7 @@ macro_rules! componentwise_binop_impl(
|
|||
let (nrows, ncols) = self.shape();
|
||||
let nrows: SameShapeR<R1, R2> = Dim::from_usize(nrows);
|
||||
let ncols: SameShapeC<C1, C2> = Dim::from_usize(ncols);
|
||||
Matrix::new_uninitialized_generic(nrows, ncols)
|
||||
crate::unimplemented_or_uninitialized_generic!(nrows, ncols)
|
||||
};
|
||||
|
||||
self.$method_to_statically_unchecked(rhs, &mut res);
|
||||
|
@ -573,9 +573,9 @@ where
|
|||
|
||||
#[inline]
|
||||
fn mul(self, rhs: &'b Matrix<N, R2, C2, SB>) -> Self::Output {
|
||||
let mut res =
|
||||
unsafe { Matrix::new_uninitialized_generic(self.data.shape().0, rhs.data.shape().1) };
|
||||
|
||||
let mut res = unsafe {
|
||||
crate::unimplemented_or_uninitialized_generic!(self.data.shape().0, rhs.data.shape().1)
|
||||
};
|
||||
self.mul_to(rhs, &mut res);
|
||||
res
|
||||
}
|
||||
|
@ -684,8 +684,9 @@ where
|
|||
DefaultAllocator: Allocator<N, C1, C2>,
|
||||
ShapeConstraint: SameNumberOfRows<R1, R2>,
|
||||
{
|
||||
let mut res =
|
||||
unsafe { Matrix::new_uninitialized_generic(self.data.shape().1, rhs.data.shape().1) };
|
||||
let mut res = unsafe {
|
||||
crate::unimplemented_or_uninitialized_generic!(self.data.shape().1, rhs.data.shape().1)
|
||||
};
|
||||
|
||||
self.tr_mul_to(rhs, &mut res);
|
||||
res
|
||||
|
@ -700,8 +701,9 @@ where
|
|||
DefaultAllocator: Allocator<N, C1, C2>,
|
||||
ShapeConstraint: SameNumberOfRows<R1, R2>,
|
||||
{
|
||||
let mut res =
|
||||
unsafe { Matrix::new_uninitialized_generic(self.data.shape().1, rhs.data.shape().1) };
|
||||
let mut res = unsafe {
|
||||
crate::unimplemented_or_uninitialized_generic!(self.data.shape().1, rhs.data.shape().1)
|
||||
};
|
||||
|
||||
self.ad_mul_to(rhs, &mut res);
|
||||
res
|
||||
|
@ -815,8 +817,9 @@ where
|
|||
let (nrows1, ncols1) = self.data.shape();
|
||||
let (nrows2, ncols2) = rhs.data.shape();
|
||||
|
||||
let mut res =
|
||||
unsafe { Matrix::new_uninitialized_generic(nrows1.mul(nrows2), ncols1.mul(ncols2)) };
|
||||
let mut res = unsafe {
|
||||
crate::unimplemented_or_uninitialized_generic!(nrows1.mul(nrows2), ncols1.mul(ncols2))
|
||||
};
|
||||
|
||||
{
|
||||
let mut data_res = res.data.ptr_mut();
|
||||
|
|
|
@ -17,7 +17,8 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
|||
DefaultAllocator: Allocator<N, U1, C>,
|
||||
{
|
||||
let ncols = self.data.shape().1;
|
||||
let mut res = unsafe { RowVectorN::new_uninitialized_generic(U1, ncols) };
|
||||
let mut res: RowVectorN<N, C> =
|
||||
unsafe { crate::unimplemented_or_uninitialized_generic!(U1, ncols) };
|
||||
|
||||
for i in 0..ncols.value() {
|
||||
// TODO: avoid bound checking of column.
|
||||
|
@ -42,7 +43,8 @@ impl<N: Scalar, R: Dim, C: Dim, S: Storage<N, R, C>> Matrix<N, R, C, S> {
|
|||
DefaultAllocator: Allocator<N, C>,
|
||||
{
|
||||
let ncols = self.data.shape().1;
|
||||
let mut res = unsafe { VectorN::new_uninitialized_generic(ncols, U1) };
|
||||
let mut res: VectorN<N, C> =
|
||||
unsafe { crate::unimplemented_or_uninitialized_generic!(ncols, U1) };
|
||||
|
||||
for i in 0..ncols.value() {
|
||||
// TODO: avoid bound checking of column.
|
||||
|
|
|
@ -30,6 +30,12 @@ pub struct Unit<T> {
|
|||
pub(crate) value: T,
|
||||
}
|
||||
|
||||
#[cfg(feature = "bytemuck")]
|
||||
unsafe impl<T> bytemuck::Zeroable for Unit<T> where T: bytemuck::Zeroable {}
|
||||
|
||||
#[cfg(feature = "bytemuck")]
|
||||
unsafe impl<T> bytemuck::Pod for Unit<T> where T: bytemuck::Pod {}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
impl<T: Serialize> Serialize for Unit<T> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
|
|
|
@ -48,9 +48,8 @@ where
|
|||
DefaultAllocator: Allocator<N, D, D>,
|
||||
Owned<N, D, D>: Clone + Send,
|
||||
{
|
||||
fn arbitrary<G: Gen>(g: &mut G) -> Self {
|
||||
use rand::Rng;
|
||||
let dim = D::try_to_usize().unwrap_or(g.gen_range(1, 50));
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let dim = D::try_to_usize().unwrap_or(1 + usize::arbitrary(g) % 50);
|
||||
Self::new(D::from_usize(dim), || N::arbitrary(g))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -51,9 +51,8 @@ where
|
|||
DefaultAllocator: Allocator<N, D, D>,
|
||||
Owned<N, D, D>: Clone + Send,
|
||||
{
|
||||
fn arbitrary<G: Gen>(g: &mut G) -> Self {
|
||||
use rand::Rng;
|
||||
let dim = D::try_to_usize().unwrap_or(g.gen_range(1, 50));
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let dim = D::try_to_usize().unwrap_or(1 + usize::arbitrary(g) % 50);
|
||||
Self::new(D::from_usize(dim), || N::arbitrary(g))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,13 @@
|
|||
use crate::{Quaternion, SimdRealField};
|
||||
use crate::{
|
||||
Isometry3, Matrix4, Normed, Point3, Quaternion, Scalar, SimdRealField, Translation3, Unit,
|
||||
UnitQuaternion, Vector3, VectorN, Zero, U8,
|
||||
};
|
||||
use approx::{AbsDiffEq, RelativeEq, UlpsEq};
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::fmt;
|
||||
|
||||
use simba::scalar::{ClosedNeg, RealField};
|
||||
|
||||
/// A dual quaternion.
|
||||
///
|
||||
|
@ -28,14 +35,23 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
|||
/// If a feature that you need is missing, feel free to open an issue or a PR.
|
||||
/// See https://github.com/dimforge/nalgebra/issues/487
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Default, Eq, PartialEq, Copy, Clone)]
|
||||
pub struct DualQuaternion<N: SimdRealField> {
|
||||
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
|
||||
pub struct DualQuaternion<N: Scalar> {
|
||||
/// The real component of the quaternion
|
||||
pub real: Quaternion<N>,
|
||||
/// The dual component of the quaternion
|
||||
pub dual: Quaternion<N>,
|
||||
}
|
||||
|
||||
impl<N: Scalar + Zero> Default for DualQuaternion<N> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
real: Quaternion::default(),
|
||||
dual: Quaternion::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: SimdRealField> DualQuaternion<N>
|
||||
where
|
||||
N::Element: SimdRealField,
|
||||
|
@ -77,8 +93,147 @@ where
|
|||
/// relative_eq!(dq.real.norm(), 1.0);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn normalize_mut(&mut self) {
|
||||
*self = self.normalize();
|
||||
pub fn normalize_mut(&mut self) -> N {
|
||||
let real_norm = self.real.norm();
|
||||
self.real /= real_norm;
|
||||
self.dual /= real_norm;
|
||||
real_norm
|
||||
}
|
||||
|
||||
/// The conjugate of this dual quaternion, containing the conjugate of
|
||||
/// the real and imaginary parts..
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use nalgebra::{DualQuaternion, Quaternion};
|
||||
/// let real = Quaternion::new(1.0, 2.0, 3.0, 4.0);
|
||||
/// let dual = Quaternion::new(5.0, 6.0, 7.0, 8.0);
|
||||
/// let dq = DualQuaternion::from_real_and_dual(real, dual);
|
||||
///
|
||||
/// let conj = dq.conjugate();
|
||||
/// assert!(conj.real.i == -2.0 && conj.real.j == -3.0 && conj.real.k == -4.0);
|
||||
/// assert!(conj.real.w == 1.0);
|
||||
/// assert!(conj.dual.i == -6.0 && conj.dual.j == -7.0 && conj.dual.k == -8.0);
|
||||
/// assert!(conj.dual.w == 5.0);
|
||||
/// ```
|
||||
#[inline]
|
||||
#[must_use = "Did you mean to use conjugate_mut()?"]
|
||||
pub fn conjugate(&self) -> Self {
|
||||
Self::from_real_and_dual(self.real.conjugate(), self.dual.conjugate())
|
||||
}
|
||||
|
||||
/// Replaces this quaternion by its conjugate.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use nalgebra::{DualQuaternion, Quaternion};
|
||||
/// let real = Quaternion::new(1.0, 2.0, 3.0, 4.0);
|
||||
/// let dual = Quaternion::new(5.0, 6.0, 7.0, 8.0);
|
||||
/// let mut dq = DualQuaternion::from_real_and_dual(real, dual);
|
||||
///
|
||||
/// dq.conjugate_mut();
|
||||
/// assert!(dq.real.i == -2.0 && dq.real.j == -3.0 && dq.real.k == -4.0);
|
||||
/// assert!(dq.real.w == 1.0);
|
||||
/// assert!(dq.dual.i == -6.0 && dq.dual.j == -7.0 && dq.dual.k == -8.0);
|
||||
/// assert!(dq.dual.w == 5.0);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn conjugate_mut(&mut self) {
|
||||
self.real.conjugate_mut();
|
||||
self.dual.conjugate_mut();
|
||||
}
|
||||
|
||||
/// Inverts this dual quaternion if it is not zero.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{DualQuaternion, Quaternion};
|
||||
/// let real = Quaternion::new(1.0, 2.0, 3.0, 4.0);
|
||||
/// let dual = Quaternion::new(5.0, 6.0, 7.0, 8.0);
|
||||
/// let dq = DualQuaternion::from_real_and_dual(real, dual);
|
||||
/// let inverse = dq.try_inverse();
|
||||
///
|
||||
/// assert!(inverse.is_some());
|
||||
/// assert_relative_eq!(inverse.unwrap() * dq, DualQuaternion::identity());
|
||||
///
|
||||
/// //Non-invertible case
|
||||
/// let zero = Quaternion::new(0.0, 0.0, 0.0, 0.0);
|
||||
/// let dq = DualQuaternion::from_real_and_dual(zero, zero);
|
||||
/// let inverse = dq.try_inverse();
|
||||
///
|
||||
/// assert!(inverse.is_none());
|
||||
/// ```
|
||||
#[inline]
|
||||
#[must_use = "Did you mean to use try_inverse_mut()?"]
|
||||
pub fn try_inverse(&self) -> Option<Self>
|
||||
where
|
||||
N: RealField,
|
||||
{
|
||||
let mut res = *self;
|
||||
if res.try_inverse_mut() {
|
||||
Some(res)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Inverts this dual quaternion in-place if it is not zero.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{DualQuaternion, Quaternion};
|
||||
/// let real = Quaternion::new(1.0, 2.0, 3.0, 4.0);
|
||||
/// let dual = Quaternion::new(5.0, 6.0, 7.0, 8.0);
|
||||
/// let dq = DualQuaternion::from_real_and_dual(real, dual);
|
||||
/// let mut dq_inverse = dq;
|
||||
/// dq_inverse.try_inverse_mut();
|
||||
///
|
||||
/// assert_relative_eq!(dq_inverse * dq, DualQuaternion::identity());
|
||||
///
|
||||
/// //Non-invertible case
|
||||
/// let zero = Quaternion::new(0.0, 0.0, 0.0, 0.0);
|
||||
/// let mut dq = DualQuaternion::from_real_and_dual(zero, zero);
|
||||
/// assert!(!dq.try_inverse_mut());
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn try_inverse_mut(&mut self) -> bool
|
||||
where
|
||||
N: RealField,
|
||||
{
|
||||
let inverted = self.real.try_inverse_mut();
|
||||
if inverted {
|
||||
self.dual = -self.real * self.dual * self.real;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Linear interpolation between two dual quaternions.
|
||||
///
|
||||
/// Computes `self * (1 - t) + other * t`.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use nalgebra::{DualQuaternion, Quaternion};
|
||||
/// let dq1 = DualQuaternion::from_real_and_dual(
|
||||
/// Quaternion::new(1.0, 0.0, 0.0, 4.0),
|
||||
/// Quaternion::new(0.0, 2.0, 0.0, 0.0)
|
||||
/// );
|
||||
/// let dq2 = DualQuaternion::from_real_and_dual(
|
||||
/// Quaternion::new(2.0, 0.0, 1.0, 0.0),
|
||||
/// Quaternion::new(0.0, 2.0, 0.0, 0.0)
|
||||
/// );
|
||||
/// assert_eq!(dq1.lerp(&dq2, 0.25), DualQuaternion::from_real_and_dual(
|
||||
/// Quaternion::new(1.25, 0.0, 0.25, 3.0),
|
||||
/// Quaternion::new(0.0, 2.0, 0.0, 0.0)
|
||||
/// ));
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn lerp(&self, other: &Self, t: N) -> Self {
|
||||
self * (N::one() - t) + other * t
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -114,3 +269,669 @@ where
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField> DualQuaternion<N> {
|
||||
fn to_vector(&self) -> VectorN<N, U8> {
|
||||
self.as_ref().clone().into()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + AbsDiffEq<Epsilon = N>> AbsDiffEq for DualQuaternion<N> {
|
||||
type Epsilon = N;
|
||||
|
||||
#[inline]
|
||||
fn default_epsilon() -> Self::Epsilon {
|
||||
N::default_epsilon()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
|
||||
self.to_vector().abs_diff_eq(&other.to_vector(), epsilon) ||
|
||||
// Account for the double-covering of S², i.e. q = -q
|
||||
self.to_vector().iter().zip(other.to_vector().iter()).all(|(a, b)| a.abs_diff_eq(&-*b, epsilon))
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + RelativeEq<Epsilon = N>> RelativeEq for DualQuaternion<N> {
|
||||
#[inline]
|
||||
fn default_max_relative() -> Self::Epsilon {
|
||||
N::default_max_relative()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn relative_eq(
|
||||
&self,
|
||||
other: &Self,
|
||||
epsilon: Self::Epsilon,
|
||||
max_relative: Self::Epsilon,
|
||||
) -> bool {
|
||||
self.to_vector().relative_eq(&other.to_vector(), epsilon, max_relative) ||
|
||||
// Account for the double-covering of S², i.e. q = -q
|
||||
self.to_vector().iter().zip(other.to_vector().iter()).all(|(a, b)| a.relative_eq(&-*b, epsilon, max_relative))
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + UlpsEq<Epsilon = N>> UlpsEq for DualQuaternion<N> {
|
||||
#[inline]
|
||||
fn default_max_ulps() -> u32 {
|
||||
N::default_max_ulps()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
|
||||
self.to_vector().ulps_eq(&other.to_vector(), epsilon, max_ulps) ||
|
||||
// Account for the double-covering of S², i.e. q = -q.
|
||||
self.to_vector().iter().zip(other.to_vector().iter()).all(|(a, b)| a.ulps_eq(&-*b, epsilon, max_ulps))
|
||||
}
|
||||
}
|
||||
|
||||
/// A unit quaternions. May be used to represent a rotation followed by a translation.
|
||||
pub type UnitDualQuaternion<N> = Unit<DualQuaternion<N>>;
|
||||
|
||||
impl<N: Scalar + ClosedNeg + PartialEq + SimdRealField> PartialEq for UnitDualQuaternion<N> {
|
||||
#[inline]
|
||||
fn eq(&self, rhs: &Self) -> bool {
|
||||
self.as_ref().eq(rhs.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: Scalar + ClosedNeg + Eq + SimdRealField> Eq for UnitDualQuaternion<N> {}
|
||||
|
||||
impl<N: SimdRealField> Normed for DualQuaternion<N> {
|
||||
type Norm = N::SimdRealField;
|
||||
|
||||
#[inline]
|
||||
fn norm(&self) -> N::SimdRealField {
|
||||
self.real.norm()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn norm_squared(&self) -> N::SimdRealField {
|
||||
self.real.norm_squared()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn scale_mut(&mut self, n: Self::Norm) {
|
||||
self.real.scale_mut(n);
|
||||
self.dual.scale_mut(n);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn unscale_mut(&mut self, n: Self::Norm) {
|
||||
self.real.unscale_mut(n);
|
||||
self.dual.unscale_mut(n);
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: SimdRealField> UnitDualQuaternion<N>
|
||||
where
|
||||
N::Element: SimdRealField,
|
||||
{
|
||||
/// The underlying dual quaternion.
|
||||
///
|
||||
/// Same as `self.as_ref()`.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use nalgebra::{DualQuaternion, UnitDualQuaternion, Quaternion};
|
||||
/// let id = UnitDualQuaternion::identity();
|
||||
/// assert_eq!(*id.dual_quaternion(), DualQuaternion::from_real_and_dual(
|
||||
/// Quaternion::new(1.0, 0.0, 0.0, 0.0),
|
||||
/// Quaternion::new(0.0, 0.0, 0.0, 0.0)
|
||||
/// ));
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn dual_quaternion(&self) -> &DualQuaternion<N> {
|
||||
self.as_ref()
|
||||
}
|
||||
|
||||
/// Compute the conjugate of this unit quaternion.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use nalgebra::{UnitDualQuaternion, DualQuaternion, Quaternion};
|
||||
/// let qr = Quaternion::new(1.0, 2.0, 3.0, 4.0);
|
||||
/// let qd = Quaternion::new(5.0, 6.0, 7.0, 8.0);
|
||||
/// let unit = UnitDualQuaternion::new_normalize(
|
||||
/// DualQuaternion::from_real_and_dual(qr, qd)
|
||||
/// );
|
||||
/// let conj = unit.conjugate();
|
||||
/// assert_eq!(conj.real, unit.real.conjugate());
|
||||
/// assert_eq!(conj.dual, unit.dual.conjugate());
|
||||
/// ```
|
||||
#[inline]
|
||||
#[must_use = "Did you mean to use conjugate_mut()?"]
|
||||
pub fn conjugate(&self) -> Self {
|
||||
Self::new_unchecked(self.as_ref().conjugate())
|
||||
}
|
||||
|
||||
/// Compute the conjugate of this unit quaternion in-place.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use nalgebra::{UnitDualQuaternion, DualQuaternion, Quaternion};
|
||||
/// let qr = Quaternion::new(1.0, 2.0, 3.0, 4.0);
|
||||
/// let qd = Quaternion::new(5.0, 6.0, 7.0, 8.0);
|
||||
/// let unit = UnitDualQuaternion::new_normalize(
|
||||
/// DualQuaternion::from_real_and_dual(qr, qd)
|
||||
/// );
|
||||
/// let mut conj = unit.clone();
|
||||
/// conj.conjugate_mut();
|
||||
/// assert_eq!(conj.as_ref().real, unit.as_ref().real.conjugate());
|
||||
/// assert_eq!(conj.as_ref().dual, unit.as_ref().dual.conjugate());
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn conjugate_mut(&mut self) {
|
||||
self.as_mut_unchecked().conjugate_mut()
|
||||
}
|
||||
|
||||
/// Inverts this dual quaternion if it is not zero.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, Quaternion, DualQuaternion};
|
||||
/// let qr = Quaternion::new(1.0, 2.0, 3.0, 4.0);
|
||||
/// let qd = Quaternion::new(5.0, 6.0, 7.0, 8.0);
|
||||
/// let unit = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(qr, qd));
|
||||
/// let inv = unit.inverse();
|
||||
/// assert_relative_eq!(unit * inv, UnitDualQuaternion::identity(), epsilon = 1.0e-6);
|
||||
/// assert_relative_eq!(inv * unit, UnitDualQuaternion::identity(), epsilon = 1.0e-6);
|
||||
/// ```
|
||||
#[inline]
|
||||
#[must_use = "Did you mean to use inverse_mut()?"]
|
||||
pub fn inverse(&self) -> Self {
|
||||
let real = Unit::new_unchecked(self.as_ref().real)
|
||||
.inverse()
|
||||
.into_inner();
|
||||
let dual = -real * self.as_ref().dual * real;
|
||||
UnitDualQuaternion::new_unchecked(DualQuaternion { real, dual })
|
||||
}
|
||||
|
||||
/// Inverts this dual quaternion in place if it is not zero.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, Quaternion, DualQuaternion};
|
||||
/// let qr = Quaternion::new(1.0, 2.0, 3.0, 4.0);
|
||||
/// let qd = Quaternion::new(5.0, 6.0, 7.0, 8.0);
|
||||
/// let unit = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(qr, qd));
|
||||
/// let mut inv = unit.clone();
|
||||
/// inv.inverse_mut();
|
||||
/// assert_relative_eq!(unit * inv, UnitDualQuaternion::identity(), epsilon = 1.0e-6);
|
||||
/// assert_relative_eq!(inv * unit, UnitDualQuaternion::identity(), epsilon = 1.0e-6);
|
||||
/// ```
|
||||
#[inline]
|
||||
#[must_use = "Did you mean to use inverse_mut()?"]
|
||||
pub fn inverse_mut(&mut self) {
|
||||
let quat = self.as_mut_unchecked();
|
||||
quat.real = Unit::new_unchecked(quat.real).inverse().into_inner();
|
||||
quat.dual = -quat.real * quat.dual * quat.real;
|
||||
}
|
||||
|
||||
/// The unit dual quaternion needed to make `self` and `other` coincide.
|
||||
///
|
||||
/// The result is such that: `self.isometry_to(other) * self == other`.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, DualQuaternion, Quaternion};
|
||||
/// let qr = Quaternion::new(1.0, 2.0, 3.0, 4.0);
|
||||
/// let qd = Quaternion::new(5.0, 6.0, 7.0, 8.0);
|
||||
/// let dq1 = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(qr, qd));
|
||||
/// let dq2 = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(qd, qr));
|
||||
/// let dq_to = dq1.isometry_to(&dq2);
|
||||
/// assert_relative_eq!(dq_to * dq1, dq2, epsilon = 1.0e-6);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn isometry_to(&self, other: &Self) -> Self {
|
||||
other / self
|
||||
}
|
||||
|
||||
/// Linear interpolation between two unit dual quaternions.
|
||||
///
|
||||
/// The result is not normalized.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, DualQuaternion, Quaternion};
|
||||
/// let dq1 = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(
|
||||
/// Quaternion::new(0.5, 0.0, 0.5, 0.0),
|
||||
/// Quaternion::new(0.0, 0.5, 0.0, 0.5)
|
||||
/// ));
|
||||
/// let dq2 = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(
|
||||
/// Quaternion::new(0.5, 0.0, 0.0, 0.5),
|
||||
/// Quaternion::new(0.5, 0.0, 0.5, 0.0)
|
||||
/// ));
|
||||
/// assert_relative_eq!(
|
||||
/// UnitDualQuaternion::new_normalize(dq1.lerp(&dq2, 0.5)),
|
||||
/// UnitDualQuaternion::new_normalize(
|
||||
/// DualQuaternion::from_real_and_dual(
|
||||
/// Quaternion::new(0.5, 0.0, 0.25, 0.25),
|
||||
/// Quaternion::new(0.25, 0.25, 0.25, 0.25)
|
||||
/// )
|
||||
/// ),
|
||||
/// epsilon = 1.0e-6
|
||||
/// );
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn lerp(&self, other: &Self, t: N) -> DualQuaternion<N> {
|
||||
self.as_ref().lerp(other.as_ref(), t)
|
||||
}
|
||||
|
||||
/// Normalized linear interpolation between two unit quaternions.
|
||||
///
|
||||
/// This is the same as `self.lerp` except that the result is normalized.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, DualQuaternion, Quaternion};
|
||||
/// let dq1 = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(
|
||||
/// Quaternion::new(0.5, 0.0, 0.5, 0.0),
|
||||
/// Quaternion::new(0.0, 0.5, 0.0, 0.5)
|
||||
/// ));
|
||||
/// let dq2 = UnitDualQuaternion::new_normalize(DualQuaternion::from_real_and_dual(
|
||||
/// Quaternion::new(0.5, 0.0, 0.0, 0.5),
|
||||
/// Quaternion::new(0.5, 0.0, 0.5, 0.0)
|
||||
/// ));
|
||||
/// assert_relative_eq!(dq1.nlerp(&dq2, 0.2), UnitDualQuaternion::new_normalize(
|
||||
/// DualQuaternion::from_real_and_dual(
|
||||
/// Quaternion::new(0.5, 0.0, 0.4, 0.1),
|
||||
/// Quaternion::new(0.1, 0.4, 0.1, 0.4)
|
||||
/// )
|
||||
/// ), epsilon = 1.0e-6);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn nlerp(&self, other: &Self, t: N) -> Self {
|
||||
let mut res = self.lerp(other, t);
|
||||
let _ = res.normalize_mut();
|
||||
|
||||
Self::new_unchecked(res)
|
||||
}
|
||||
|
||||
/// Screw linear interpolation between two unit quaternions. This creates a
|
||||
/// smooth arc from one dual-quaternion to another.
|
||||
///
|
||||
/// Panics if the angle between both quaternion is 180 degrees (in which case the interpolation
|
||||
/// is not well-defined). Use `.try_sclerp` instead to avoid the panic.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, DualQuaternion, UnitQuaternion, Vector3};
|
||||
///
|
||||
/// let dq1 = UnitDualQuaternion::from_parts(
|
||||
/// Vector3::new(0.0, 3.0, 0.0).into(),
|
||||
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_4, 0.0, 0.0),
|
||||
/// );
|
||||
///
|
||||
/// let dq2 = UnitDualQuaternion::from_parts(
|
||||
/// Vector3::new(0.0, 0.0, 3.0).into(),
|
||||
/// UnitQuaternion::from_euler_angles(-std::f32::consts::PI, 0.0, 0.0),
|
||||
/// );
|
||||
///
|
||||
/// let dq = dq1.sclerp(&dq2, 1.0 / 3.0);
|
||||
///
|
||||
/// assert_relative_eq!(
|
||||
/// dq.rotation().euler_angles().0, std::f32::consts::FRAC_PI_2, epsilon = 1.0e-6
|
||||
/// );
|
||||
/// assert_relative_eq!(dq.translation().vector.y, 3.0, epsilon = 1.0e-6);
|
||||
#[inline]
|
||||
pub fn sclerp(&self, other: &Self, t: N) -> Self
|
||||
where
|
||||
N: RealField,
|
||||
{
|
||||
self.try_sclerp(other, t, N::default_epsilon())
|
||||
.expect("DualQuaternion sclerp: ambiguous configuration.")
|
||||
}
|
||||
|
||||
/// Computes the screw-linear interpolation between two unit quaternions or returns `None`
|
||||
/// if both quaternions are approximately 180 degrees apart (in which case the interpolation is
|
||||
/// not well-defined).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `self`: the first quaternion to interpolate from.
|
||||
/// * `other`: the second quaternion to interpolate toward.
|
||||
/// * `t`: the interpolation parameter. Should be between 0 and 1.
|
||||
/// * `epsilon`: the value below which the sinus of the angle separating both quaternion
|
||||
/// must be to return `None`.
|
||||
#[inline]
|
||||
pub fn try_sclerp(&self, other: &Self, t: N, epsilon: N) -> Option<Self>
|
||||
where
|
||||
N: RealField,
|
||||
{
|
||||
let two = N::one() + N::one();
|
||||
let half = N::one() / two;
|
||||
|
||||
// Invert one of the quaternions if we've got a longest-path
|
||||
// interpolation.
|
||||
let other = {
|
||||
let dot_product = self.as_ref().real.coords.dot(&other.as_ref().real.coords);
|
||||
if dot_product < N::zero() {
|
||||
-other.clone()
|
||||
} else {
|
||||
other.clone()
|
||||
}
|
||||
};
|
||||
|
||||
let difference = self.as_ref().conjugate() * other.as_ref();
|
||||
let norm_squared = difference.real.vector().norm_squared();
|
||||
if relative_eq!(norm_squared, N::zero(), epsilon = epsilon) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let inverse_norm_squared = N::one() / norm_squared;
|
||||
let inverse_norm = inverse_norm_squared.sqrt();
|
||||
|
||||
let mut angle = two * difference.real.scalar().acos();
|
||||
let mut pitch = -two * difference.dual.scalar() * inverse_norm;
|
||||
let direction = difference.real.vector() * inverse_norm;
|
||||
let moment = (difference.dual.vector()
|
||||
- direction * (pitch * difference.real.scalar() * half))
|
||||
* inverse_norm;
|
||||
|
||||
angle *= t;
|
||||
pitch *= t;
|
||||
|
||||
let sin = (half * angle).sin();
|
||||
let cos = (half * angle).cos();
|
||||
let real = Quaternion::from_parts(cos, direction * sin);
|
||||
let dual = Quaternion::from_parts(
|
||||
-pitch * half * sin,
|
||||
moment * sin + direction * (pitch * half * cos),
|
||||
);
|
||||
|
||||
Some(
|
||||
self * UnitDualQuaternion::new_unchecked(DualQuaternion::from_real_and_dual(
|
||||
real, dual,
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
/// Return the rotation part of this unit dual quaternion.
|
||||
///
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3};
|
||||
/// let dq = UnitDualQuaternion::from_parts(
|
||||
/// Vector3::new(0.0, 3.0, 0.0).into(),
|
||||
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_4, 0.0, 0.0)
|
||||
/// );
|
||||
///
|
||||
/// assert_relative_eq!(
|
||||
/// dq.rotation().angle(), std::f32::consts::FRAC_PI_4, epsilon = 1.0e-6
|
||||
/// );
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn rotation(&self) -> UnitQuaternion<N> {
|
||||
Unit::new_unchecked(self.as_ref().real)
|
||||
}
|
||||
|
||||
/// Return the translation part of this unit dual quaternion.
|
||||
///
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3};
|
||||
/// let dq = UnitDualQuaternion::from_parts(
|
||||
/// Vector3::new(0.0, 3.0, 0.0).into(),
|
||||
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_4, 0.0, 0.0)
|
||||
/// );
|
||||
///
|
||||
/// assert_relative_eq!(
|
||||
/// dq.translation().vector, Vector3::new(0.0, 3.0, 0.0), epsilon = 1.0e-6
|
||||
/// );
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn translation(&self) -> Translation3<N> {
|
||||
let two = N::one() + N::one();
|
||||
Translation3::from(
|
||||
((self.as_ref().dual * self.as_ref().real.conjugate()) * two)
|
||||
.vector()
|
||||
.into_owned(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Builds an isometry from this unit dual quaternion.
|
||||
///
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3};
|
||||
/// let rotation = UnitQuaternion::from_euler_angles(std::f32::consts::PI, 0.0, 0.0);
|
||||
/// let translation = Vector3::new(1.0, 3.0, 2.5);
|
||||
/// let dq = UnitDualQuaternion::from_parts(
|
||||
/// translation.into(),
|
||||
/// rotation
|
||||
/// );
|
||||
/// let iso = dq.to_isometry();
|
||||
///
|
||||
/// assert_relative_eq!(iso.rotation.angle(), std::f32::consts::PI, epsilon = 1.0e-6);
|
||||
/// assert_relative_eq!(iso.translation.vector, translation, epsilon = 1.0e-6);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn to_isometry(&self) -> Isometry3<N> {
|
||||
Isometry3::from_parts(self.translation(), self.rotation())
|
||||
}
|
||||
|
||||
/// Rotate and translate a point by this unit dual quaternion interpreted
|
||||
/// as an isometry.
|
||||
///
|
||||
/// This is the same as the multiplication `self * pt`.
|
||||
///
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3, Point3};
|
||||
/// let dq = UnitDualQuaternion::from_parts(
|
||||
/// Vector3::new(0.0, 3.0, 0.0).into(),
|
||||
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
|
||||
/// );
|
||||
/// let point = Point3::new(1.0, 2.0, 3.0);
|
||||
///
|
||||
/// assert_relative_eq!(
|
||||
/// dq.transform_point(&point), Point3::new(1.0, 0.0, 2.0), epsilon = 1.0e-6
|
||||
/// );
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn transform_point(&self, pt: &Point3<N>) -> Point3<N> {
|
||||
self * pt
|
||||
}
|
||||
|
||||
/// Rotate a vector by this unit dual quaternion, ignoring the translational
|
||||
/// component.
|
||||
///
|
||||
/// This is the same as the multiplication `self * v`.
|
||||
///
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3};
|
||||
/// let dq = UnitDualQuaternion::from_parts(
|
||||
/// Vector3::new(0.0, 3.0, 0.0).into(),
|
||||
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
|
||||
/// );
|
||||
/// let vector = Vector3::new(1.0, 2.0, 3.0);
|
||||
///
|
||||
/// assert_relative_eq!(
|
||||
/// dq.transform_vector(&vector), Vector3::new(1.0, -3.0, 2.0), epsilon = 1.0e-6
|
||||
/// );
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn transform_vector(&self, v: &Vector3<N>) -> Vector3<N> {
|
||||
self * v
|
||||
}
|
||||
|
||||
/// Rotate and translate a point by the inverse of this unit quaternion.
|
||||
///
|
||||
/// This may be cheaper than inverting the unit dual quaternion and
|
||||
/// transforming the point.
|
||||
///
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3, Point3};
|
||||
/// let dq = UnitDualQuaternion::from_parts(
|
||||
/// Vector3::new(0.0, 3.0, 0.0).into(),
|
||||
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
|
||||
/// );
|
||||
/// let point = Point3::new(1.0, 2.0, 3.0);
|
||||
///
|
||||
/// assert_relative_eq!(
|
||||
/// dq.inverse_transform_point(&point), Point3::new(1.0, 3.0, 1.0), epsilon = 1.0e-6
|
||||
/// );
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn inverse_transform_point(&self, pt: &Point3<N>) -> Point3<N> {
|
||||
self.inverse() * pt
|
||||
}
|
||||
|
||||
/// Rotate a vector by the inverse of this unit quaternion, ignoring the
|
||||
/// translational component.
|
||||
///
|
||||
/// This may be cheaper than inverting the unit dual quaternion and
|
||||
/// transforming the vector.
|
||||
///
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3};
|
||||
/// let dq = UnitDualQuaternion::from_parts(
|
||||
/// Vector3::new(0.0, 3.0, 0.0).into(),
|
||||
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
|
||||
/// );
|
||||
/// let vector = Vector3::new(1.0, 2.0, 3.0);
|
||||
///
|
||||
/// assert_relative_eq!(
|
||||
/// dq.inverse_transform_vector(&vector), Vector3::new(1.0, 3.0, -2.0), epsilon = 1.0e-6
|
||||
/// );
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn inverse_transform_vector(&self, v: &Vector3<N>) -> Vector3<N> {
|
||||
self.inverse() * v
|
||||
}
|
||||
|
||||
/// Rotate a unit vector by the inverse of this unit quaternion, ignoring
|
||||
/// the translational component. This may be
|
||||
/// cheaper than inverting the unit dual quaternion and transforming the
|
||||
/// vector.
|
||||
///
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Unit, Vector3};
|
||||
/// let dq = UnitDualQuaternion::from_parts(
|
||||
/// Vector3::new(0.0, 3.0, 0.0).into(),
|
||||
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
|
||||
/// );
|
||||
/// let vector = Unit::new_unchecked(Vector3::new(0.0, 1.0, 0.0));
|
||||
///
|
||||
/// assert_relative_eq!(
|
||||
/// dq.inverse_transform_unit_vector(&vector),
|
||||
/// Unit::new_unchecked(Vector3::new(0.0, 0.0, -1.0)),
|
||||
/// epsilon = 1.0e-6
|
||||
/// );
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn inverse_transform_unit_vector(&self, v: &Unit<Vector3<N>>) -> Unit<Vector3<N>> {
|
||||
self.inverse() * v
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: SimdRealField + RealField> UnitDualQuaternion<N>
|
||||
where
|
||||
N::Element: SimdRealField,
|
||||
{
|
||||
/// Converts this unit dual quaternion interpreted as an isometry
|
||||
/// into its equivalent homogeneous transformation matrix.
|
||||
///
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{Matrix4, UnitDualQuaternion, UnitQuaternion, Vector3};
|
||||
/// let dq = UnitDualQuaternion::from_parts(
|
||||
/// Vector3::new(1.0, 3.0, 2.0).into(),
|
||||
/// UnitQuaternion::from_axis_angle(&Vector3::z_axis(), std::f32::consts::FRAC_PI_6)
|
||||
/// );
|
||||
/// let expected = Matrix4::new(0.8660254, -0.5, 0.0, 1.0,
|
||||
/// 0.5, 0.8660254, 0.0, 3.0,
|
||||
/// 0.0, 0.0, 1.0, 2.0,
|
||||
/// 0.0, 0.0, 0.0, 1.0);
|
||||
///
|
||||
/// assert_relative_eq!(dq.to_homogeneous(), expected, epsilon = 1.0e-6);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn to_homogeneous(&self) -> Matrix4<N> {
|
||||
self.to_isometry().to_homogeneous()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField> Default for UnitDualQuaternion<N> {
|
||||
fn default() -> Self {
|
||||
Self::identity()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + fmt::Display> fmt::Display for UnitDualQuaternion<N> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
if let Some(axis) = self.rotation().axis() {
|
||||
let axis = axis.into_inner();
|
||||
write!(
|
||||
f,
|
||||
"UnitDualQuaternion translation: {} − angle: {} − axis: ({}, {}, {})",
|
||||
self.translation().vector,
|
||||
self.rotation().angle(),
|
||||
axis[0],
|
||||
axis[1],
|
||||
axis[2]
|
||||
)
|
||||
} else {
|
||||
write!(
|
||||
f,
|
||||
"UnitDualQuaternion translation: {} − angle: {} − axis: (undefined)",
|
||||
self.translation().vector,
|
||||
self.rotation().angle()
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + AbsDiffEq<Epsilon = N>> AbsDiffEq for UnitDualQuaternion<N> {
|
||||
type Epsilon = N;
|
||||
|
||||
#[inline]
|
||||
fn default_epsilon() -> Self::Epsilon {
|
||||
N::default_epsilon()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
|
||||
self.as_ref().abs_diff_eq(other.as_ref(), epsilon)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + RelativeEq<Epsilon = N>> RelativeEq for UnitDualQuaternion<N> {
|
||||
#[inline]
|
||||
fn default_max_relative() -> Self::Epsilon {
|
||||
N::default_max_relative()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn relative_eq(
|
||||
&self,
|
||||
other: &Self,
|
||||
epsilon: Self::Epsilon,
|
||||
max_relative: Self::Epsilon,
|
||||
) -> bool {
|
||||
self.as_ref()
|
||||
.relative_eq(other.as_ref(), epsilon, max_relative)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + UlpsEq<Epsilon = N>> UlpsEq for UnitDualQuaternion<N> {
|
||||
#[inline]
|
||||
fn default_max_ulps() -> u32 {
|
||||
N::default_max_ulps()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn ulps_eq(&self, other: &Self, epsilon: Self::Epsilon, max_ulps: u32) -> bool {
|
||||
self.as_ref().ulps_eq(other.as_ref(), epsilon, max_ulps)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,324 @@
|
|||
use num::Zero;
|
||||
|
||||
use alga::general::{
|
||||
AbstractGroup, AbstractGroupAbelian, AbstractLoop, AbstractMagma, AbstractModule,
|
||||
AbstractMonoid, AbstractQuasigroup, AbstractSemigroup, Additive, Id, Identity, Module,
|
||||
Multiplicative, RealField, TwoSidedInverse,
|
||||
};
|
||||
use alga::linear::{
|
||||
AffineTransformation, DirectIsometry, FiniteDimVectorSpace, Isometry, NormedSpace,
|
||||
ProjectiveTransformation, Similarity, Transformation, VectorSpace,
|
||||
};
|
||||
|
||||
use crate::base::Vector3;
|
||||
use crate::geometry::{
|
||||
DualQuaternion, Point3, Quaternion, Translation3, UnitDualQuaternion, UnitQuaternion,
|
||||
};
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> Identity<Multiplicative> for DualQuaternion<N> {
|
||||
#[inline]
|
||||
fn identity() -> Self {
|
||||
Self::identity()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> Identity<Additive> for DualQuaternion<N> {
|
||||
#[inline]
|
||||
fn identity() -> Self {
|
||||
Self::zero()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> AbstractMagma<Multiplicative> for DualQuaternion<N> {
|
||||
#[inline]
|
||||
fn operate(&self, rhs: &Self) -> Self {
|
||||
self * rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> AbstractMagma<Additive> for DualQuaternion<N> {
|
||||
#[inline]
|
||||
fn operate(&self, rhs: &Self) -> Self {
|
||||
self + rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> TwoSidedInverse<Additive> for DualQuaternion<N> {
|
||||
#[inline]
|
||||
fn two_sided_inverse(&self) -> Self {
|
||||
-self
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_structures(
|
||||
($DualQuaternion: ident; $($marker: ident<$operator: ident>),* $(,)*) => {$(
|
||||
impl<N: RealField + simba::scalar::RealField> $marker<$operator> for $DualQuaternion<N> { }
|
||||
)*}
|
||||
);
|
||||
|
||||
impl_structures!(
|
||||
DualQuaternion;
|
||||
AbstractSemigroup<Multiplicative>,
|
||||
AbstractMonoid<Multiplicative>,
|
||||
|
||||
AbstractSemigroup<Additive>,
|
||||
AbstractQuasigroup<Additive>,
|
||||
AbstractMonoid<Additive>,
|
||||
AbstractLoop<Additive>,
|
||||
AbstractGroup<Additive>,
|
||||
AbstractGroupAbelian<Additive>
|
||||
);
|
||||
|
||||
/*
|
||||
*
|
||||
* Vector space.
|
||||
*
|
||||
*/
|
||||
impl<N: RealField + simba::scalar::RealField> AbstractModule for DualQuaternion<N> {
|
||||
type AbstractRing = N;
|
||||
|
||||
#[inline]
|
||||
fn multiply_by(&self, n: N) -> Self {
|
||||
self * n
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> Module for DualQuaternion<N> {
|
||||
type Ring = N;
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> VectorSpace for DualQuaternion<N> {
|
||||
type Field = N;
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> FiniteDimVectorSpace for DualQuaternion<N> {
|
||||
#[inline]
|
||||
fn dimension() -> usize {
|
||||
8
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn canonical_basis_element(i: usize) -> Self {
|
||||
if i < 4 {
|
||||
DualQuaternion::from_real_and_dual(
|
||||
Quaternion::canonical_basis_element(i),
|
||||
Quaternion::zero(),
|
||||
)
|
||||
} else {
|
||||
DualQuaternion::from_real_and_dual(
|
||||
Quaternion::zero(),
|
||||
Quaternion::canonical_basis_element(i - 4),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn dot(&self, other: &Self) -> N {
|
||||
self.real.dot(&other.real) + self.dual.dot(&other.dual)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn component_unchecked(&self, i: usize) -> &N {
|
||||
self.as_ref().get_unchecked(i)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
unsafe fn component_unchecked_mut(&mut self, i: usize) -> &mut N {
|
||||
self.as_mut().get_unchecked_mut(i)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> NormedSpace for DualQuaternion<N> {
|
||||
type RealField = N;
|
||||
type ComplexField = N;
|
||||
|
||||
#[inline]
|
||||
fn norm_squared(&self) -> N {
|
||||
self.real.norm_squared()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn norm(&self) -> N {
|
||||
self.real.norm()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn normalize(&self) -> Self {
|
||||
self.normalize()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn normalize_mut(&mut self) -> N {
|
||||
self.normalize_mut()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn try_normalize(&self, min_norm: N) -> Option<Self> {
|
||||
let real_norm = self.real.norm();
|
||||
if real_norm > min_norm {
|
||||
Some(Self::from_real_and_dual(
|
||||
self.real / real_norm,
|
||||
self.dual / real_norm,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn try_normalize_mut(&mut self, min_norm: N) -> Option<N> {
|
||||
let real_norm = self.real.norm();
|
||||
if real_norm > min_norm {
|
||||
self.real /= real_norm;
|
||||
self.dual /= real_norm;
|
||||
Some(real_norm)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
*
|
||||
* Implementations for UnitDualQuaternion.
|
||||
*
|
||||
*/
|
||||
impl<N: RealField + simba::scalar::RealField> Identity<Multiplicative> for UnitDualQuaternion<N> {
|
||||
#[inline]
|
||||
fn identity() -> Self {
|
||||
Self::identity()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> AbstractMagma<Multiplicative>
|
||||
for UnitDualQuaternion<N>
|
||||
{
|
||||
#[inline]
|
||||
fn operate(&self, rhs: &Self) -> Self {
|
||||
self * rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> TwoSidedInverse<Multiplicative>
|
||||
for UnitDualQuaternion<N>
|
||||
{
|
||||
#[inline]
|
||||
fn two_sided_inverse(&self) -> Self {
|
||||
self.inverse()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn two_sided_inverse_mut(&mut self) {
|
||||
self.inverse_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl_structures!(
|
||||
UnitDualQuaternion;
|
||||
AbstractSemigroup<Multiplicative>,
|
||||
AbstractQuasigroup<Multiplicative>,
|
||||
AbstractMonoid<Multiplicative>,
|
||||
AbstractLoop<Multiplicative>,
|
||||
AbstractGroup<Multiplicative>
|
||||
);
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> Transformation<Point3<N>> for UnitDualQuaternion<N> {
|
||||
#[inline]
|
||||
fn transform_point(&self, pt: &Point3<N>) -> Point3<N> {
|
||||
self.transform_point(pt)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn transform_vector(&self, v: &Vector3<N>) -> Vector3<N> {
|
||||
self.transform_vector(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> ProjectiveTransformation<Point3<N>>
|
||||
for UnitDualQuaternion<N>
|
||||
{
|
||||
#[inline]
|
||||
fn inverse_transform_point(&self, pt: &Point3<N>) -> Point3<N> {
|
||||
self.inverse_transform_point(pt)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn inverse_transform_vector(&self, v: &Vector3<N>) -> Vector3<N> {
|
||||
self.inverse_transform_vector(v)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> AffineTransformation<Point3<N>>
|
||||
for UnitDualQuaternion<N>
|
||||
{
|
||||
type Rotation = UnitQuaternion<N>;
|
||||
type NonUniformScaling = Id;
|
||||
type Translation = Translation3<N>;
|
||||
|
||||
#[inline]
|
||||
fn decompose(&self) -> (Self::Translation, Self::Rotation, Id, Self::Rotation) {
|
||||
(
|
||||
self.translation(),
|
||||
self.rotation(),
|
||||
Id::new(),
|
||||
UnitQuaternion::identity(),
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn append_translation(&self, translation: &Self::Translation) -> Self {
|
||||
self * Self::from_parts(translation.clone(), UnitQuaternion::identity())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn prepend_translation(&self, translation: &Self::Translation) -> Self {
|
||||
Self::from_parts(translation.clone(), UnitQuaternion::identity()) * self
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn append_rotation(&self, r: &Self::Rotation) -> Self {
|
||||
r * self
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn prepend_rotation(&self, r: &Self::Rotation) -> Self {
|
||||
self * r
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn append_scaling(&self, _: &Self::NonUniformScaling) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn prepend_scaling(&self, _: &Self::NonUniformScaling) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField + simba::scalar::RealField> Similarity<Point3<N>> for UnitDualQuaternion<N> {
|
||||
type Scaling = Id;
|
||||
|
||||
#[inline]
|
||||
fn translation(&self) -> Translation3<N> {
|
||||
self.translation()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn rotation(&self) -> UnitQuaternion<N> {
|
||||
self.rotation()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn scaling(&self) -> Id {
|
||||
Id::new()
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! marker_impl(
|
||||
($($Trait: ident),*) => {$(
|
||||
impl<N: RealField + simba::scalar::RealField> $Trait<Point3<N>> for UnitDualQuaternion<N> { }
|
||||
)*}
|
||||
);
|
||||
|
||||
marker_impl!(Isometry, DirectIsometry);
|
|
@ -1,6 +1,12 @@
|
|||
use crate::{DualQuaternion, Quaternion, SimdRealField};
|
||||
use crate::{
|
||||
DualQuaternion, Isometry3, Quaternion, Scalar, SimdRealField, Translation3, UnitDualQuaternion,
|
||||
UnitQuaternion,
|
||||
};
|
||||
use num::{One, Zero};
|
||||
#[cfg(feature = "arbitrary")]
|
||||
use quickcheck::{Arbitrary, Gen};
|
||||
|
||||
impl<N: SimdRealField> DualQuaternion<N> {
|
||||
impl<N: Scalar> DualQuaternion<N> {
|
||||
/// Creates a dual quaternion from its rotation and translation components.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -16,7 +22,8 @@ impl<N: SimdRealField> DualQuaternion<N> {
|
|||
pub fn from_real_and_dual(real: Quaternion<N>, dual: Quaternion<N>) -> Self {
|
||||
Self { real, dual }
|
||||
}
|
||||
/// The dual quaternion multiplicative identity
|
||||
|
||||
/// The dual quaternion multiplicative identity.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
|
@ -33,10 +40,183 @@ impl<N: SimdRealField> DualQuaternion<N> {
|
|||
/// assert_eq!(dq2 * dq1, dq2);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn identity() -> Self {
|
||||
pub fn identity() -> Self
|
||||
where
|
||||
N: SimdRealField,
|
||||
{
|
||||
Self::from_real_and_dual(
|
||||
Quaternion::from_real(N::one()),
|
||||
Quaternion::from_real(N::zero()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: SimdRealField> DualQuaternion<N>
|
||||
where
|
||||
N::Element: SimdRealField,
|
||||
{
|
||||
/// Creates a dual quaternion from only its real part, with no translation
|
||||
/// component.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # use nalgebra::{DualQuaternion, Quaternion};
|
||||
/// let rot = Quaternion::new(1.0, 2.0, 3.0, 4.0);
|
||||
///
|
||||
/// let dq = DualQuaternion::from_real(rot);
|
||||
/// assert_eq!(dq.real.w, 1.0);
|
||||
/// assert_eq!(dq.dual.w, 0.0);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn from_real(real: Quaternion<N>) -> Self {
|
||||
Self {
|
||||
real,
|
||||
dual: Quaternion::zero(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: SimdRealField> One for DualQuaternion<N>
|
||||
where
|
||||
N::Element: SimdRealField,
|
||||
{
|
||||
#[inline]
|
||||
fn one() -> Self {
|
||||
Self::identity()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: SimdRealField> Zero for DualQuaternion<N>
|
||||
where
|
||||
N::Element: SimdRealField,
|
||||
{
|
||||
#[inline]
|
||||
fn zero() -> Self {
|
||||
DualQuaternion::from_real_and_dual(Quaternion::zero(), Quaternion::zero())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_zero(&self) -> bool {
|
||||
self.real.is_zero() && self.dual.is_zero()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "arbitrary")]
|
||||
impl<N> Arbitrary for DualQuaternion<N>
|
||||
where
|
||||
N: SimdRealField + Arbitrary + Send,
|
||||
N::Element: SimdRealField,
|
||||
{
|
||||
#[inline]
|
||||
fn arbitrary(rng: &mut Gen) -> Self {
|
||||
Self::from_real_and_dual(Arbitrary::arbitrary(rng), Arbitrary::arbitrary(rng))
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: SimdRealField> UnitDualQuaternion<N> {
|
||||
/// The unit dual quaternion multiplicative identity, which also represents
|
||||
/// the identity transformation as an isometry.
|
||||
///
|
||||
/// ```
|
||||
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3, Point3};
|
||||
/// let ident = UnitDualQuaternion::identity();
|
||||
/// let point = Point3::new(1.0, -4.3, 3.33);
|
||||
///
|
||||
/// assert_eq!(ident * point, point);
|
||||
/// assert_eq!(ident, ident.inverse());
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn identity() -> Self {
|
||||
Self::new_unchecked(DualQuaternion::identity())
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: SimdRealField> UnitDualQuaternion<N>
|
||||
where
|
||||
N::Element: SimdRealField,
|
||||
{
|
||||
/// Return a dual quaternion representing the translation and orientation
|
||||
/// given by the provided rotation quaternion and translation vector.
|
||||
///
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitDualQuaternion, UnitQuaternion, Vector3, Point3};
|
||||
/// let dq = UnitDualQuaternion::from_parts(
|
||||
/// Vector3::new(0.0, 3.0, 0.0).into(),
|
||||
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
|
||||
/// );
|
||||
/// let point = Point3::new(1.0, 2.0, 3.0);
|
||||
///
|
||||
/// assert_relative_eq!(dq * point, Point3::new(1.0, 0.0, 2.0), epsilon = 1.0e-6);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn from_parts(translation: Translation3<N>, rotation: UnitQuaternion<N>) -> Self {
|
||||
let half: N = crate::convert(0.5f64);
|
||||
UnitDualQuaternion::new_unchecked(DualQuaternion {
|
||||
real: rotation.clone().into_inner(),
|
||||
dual: Quaternion::from_parts(N::zero(), translation.vector)
|
||||
* rotation.clone().into_inner()
|
||||
* half,
|
||||
})
|
||||
}
|
||||
|
||||
/// Return a unit dual quaternion representing the translation and orientation
|
||||
/// given by the provided isometry.
|
||||
///
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{Isometry3, UnitDualQuaternion, UnitQuaternion, Vector3, Point3};
|
||||
/// let iso = Isometry3::from_parts(
|
||||
/// Vector3::new(0.0, 3.0, 0.0).into(),
|
||||
/// UnitQuaternion::from_euler_angles(std::f32::consts::FRAC_PI_2, 0.0, 0.0)
|
||||
/// );
|
||||
/// let dq = UnitDualQuaternion::from_isometry(&iso);
|
||||
/// let point = Point3::new(1.0, 2.0, 3.0);
|
||||
///
|
||||
/// assert_relative_eq!(dq * point, iso * point, epsilon = 1.0e-6);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn from_isometry(isometry: &Isometry3<N>) -> Self {
|
||||
UnitDualQuaternion::from_parts(isometry.translation, isometry.rotation)
|
||||
}
|
||||
|
||||
/// Creates a dual quaternion from a unit quaternion rotation.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// # #[macro_use] extern crate approx;
|
||||
/// # use nalgebra::{UnitQuaternion, UnitDualQuaternion, Quaternion};
|
||||
/// let q = Quaternion::new(1.0, 2.0, 3.0, 4.0);
|
||||
/// let rot = UnitQuaternion::new_normalize(q);
|
||||
///
|
||||
/// let dq = UnitDualQuaternion::from_rotation(rot);
|
||||
/// assert_relative_eq!(dq.as_ref().real.norm(), 1.0, epsilon = 1.0e-6);
|
||||
/// assert_eq!(dq.as_ref().dual.norm(), 0.0);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn from_rotation(rotation: UnitQuaternion<N>) -> Self {
|
||||
Self::new_unchecked(DualQuaternion::from_real(rotation.into_inner()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: SimdRealField> One for UnitDualQuaternion<N>
|
||||
where
|
||||
N::Element: SimdRealField,
|
||||
{
|
||||
#[inline]
|
||||
fn one() -> Self {
|
||||
Self::identity()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "arbitrary")]
|
||||
impl<N> Arbitrary for UnitDualQuaternion<N>
|
||||
where
|
||||
N: SimdRealField + Arbitrary + Send,
|
||||
N::Element: SimdRealField,
|
||||
{
|
||||
#[inline]
|
||||
fn arbitrary(rng: &mut Gen) -> Self {
|
||||
Self::new_normalize(Arbitrary::arbitrary(rng))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,188 @@
|
|||
use simba::scalar::{RealField, SubsetOf, SupersetOf};
|
||||
use simba::simd::SimdRealField;
|
||||
|
||||
use crate::base::dimension::U3;
|
||||
use crate::base::{Matrix4, Vector4};
|
||||
use crate::geometry::{
|
||||
DualQuaternion, Isometry3, Similarity3, SuperTCategoryOf, TAffine, Transform, Translation3,
|
||||
UnitDualQuaternion, UnitQuaternion,
|
||||
};
|
||||
|
||||
/*
|
||||
* This file provides the following conversions:
|
||||
* =============================================
|
||||
*
|
||||
* DualQuaternion -> DualQuaternion
|
||||
* UnitDualQuaternion -> UnitDualQuaternion
|
||||
* UnitDualQuaternion -> Isometry<U3>
|
||||
* UnitDualQuaternion -> Similarity<U3>
|
||||
* UnitDualQuaternion -> Transform<U3>
|
||||
* UnitDualQuaternion -> Matrix<U4> (homogeneous)
|
||||
*
|
||||
* NOTE:
|
||||
* UnitDualQuaternion -> DualQuaternion is already provided by: Unit<T> -> T
|
||||
*/
|
||||
|
||||
impl<N1, N2> SubsetOf<DualQuaternion<N2>> for DualQuaternion<N1>
|
||||
where
|
||||
N1: SimdRealField,
|
||||
N2: SimdRealField + SupersetOf<N1>,
|
||||
{
|
||||
#[inline]
|
||||
fn to_superset(&self) -> DualQuaternion<N2> {
|
||||
DualQuaternion::from_real_and_dual(self.real.to_superset(), self.dual.to_superset())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_in_subset(dq: &DualQuaternion<N2>) -> bool {
|
||||
crate::is_convertible::<_, Vector4<N1>>(&dq.real.coords)
|
||||
&& crate::is_convertible::<_, Vector4<N1>>(&dq.dual.coords)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_superset_unchecked(dq: &DualQuaternion<N2>) -> Self {
|
||||
DualQuaternion::from_real_and_dual(
|
||||
dq.real.to_subset_unchecked(),
|
||||
dq.dual.to_subset_unchecked(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N1, N2> SubsetOf<UnitDualQuaternion<N2>> for UnitDualQuaternion<N1>
|
||||
where
|
||||
N1: SimdRealField,
|
||||
N2: SimdRealField + SupersetOf<N1>,
|
||||
{
|
||||
#[inline]
|
||||
fn to_superset(&self) -> UnitDualQuaternion<N2> {
|
||||
UnitDualQuaternion::new_unchecked(self.as_ref().to_superset())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_in_subset(dq: &UnitDualQuaternion<N2>) -> bool {
|
||||
crate::is_convertible::<_, DualQuaternion<N1>>(dq.as_ref())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_superset_unchecked(dq: &UnitDualQuaternion<N2>) -> Self {
|
||||
Self::new_unchecked(crate::convert_ref_unchecked(dq.as_ref()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<N1, N2> SubsetOf<Isometry3<N2>> for UnitDualQuaternion<N1>
|
||||
where
|
||||
N1: RealField,
|
||||
N2: RealField + SupersetOf<N1>,
|
||||
{
|
||||
#[inline]
|
||||
fn to_superset(&self) -> Isometry3<N2> {
|
||||
let dq: UnitDualQuaternion<N2> = self.to_superset();
|
||||
let iso = dq.to_isometry();
|
||||
crate::convert_unchecked(iso)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_in_subset(iso: &Isometry3<N2>) -> bool {
|
||||
crate::is_convertible::<_, UnitQuaternion<N1>>(&iso.rotation)
|
||||
&& crate::is_convertible::<_, Translation3<N1>>(&iso.translation)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_superset_unchecked(iso: &Isometry3<N2>) -> Self {
|
||||
let dq = UnitDualQuaternion::<N2>::from_isometry(iso);
|
||||
crate::convert_unchecked(dq)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N1, N2> SubsetOf<Similarity3<N2>> for UnitDualQuaternion<N1>
|
||||
where
|
||||
N1: RealField,
|
||||
N2: RealField + SupersetOf<N1>,
|
||||
{
|
||||
#[inline]
|
||||
fn to_superset(&self) -> Similarity3<N2> {
|
||||
Similarity3::from_isometry(crate::convert_ref(self), N2::one())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_in_subset(sim: &Similarity3<N2>) -> bool {
|
||||
sim.scaling() == N2::one()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_superset_unchecked(sim: &Similarity3<N2>) -> Self {
|
||||
crate::convert_ref_unchecked(&sim.isometry)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N1, N2, C> SubsetOf<Transform<N2, U3, C>> for UnitDualQuaternion<N1>
|
||||
where
|
||||
N1: RealField,
|
||||
N2: RealField + SupersetOf<N1>,
|
||||
C: SuperTCategoryOf<TAffine>,
|
||||
{
|
||||
#[inline]
|
||||
fn to_superset(&self) -> Transform<N2, U3, C> {
|
||||
Transform::from_matrix_unchecked(self.to_homogeneous().to_superset())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_in_subset(t: &Transform<N2, U3, C>) -> bool {
|
||||
<Self as SubsetOf<_>>::is_in_subset(t.matrix())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_superset_unchecked(t: &Transform<N2, U3, C>) -> Self {
|
||||
Self::from_superset_unchecked(t.matrix())
|
||||
}
|
||||
}
|
||||
|
||||
impl<N1: RealField, N2: RealField + SupersetOf<N1>> SubsetOf<Matrix4<N2>>
|
||||
for UnitDualQuaternion<N1>
|
||||
{
|
||||
#[inline]
|
||||
fn to_superset(&self) -> Matrix4<N2> {
|
||||
self.to_homogeneous().to_superset()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_in_subset(m: &Matrix4<N2>) -> bool {
|
||||
crate::is_convertible::<_, Isometry3<N1>>(m)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_superset_unchecked(m: &Matrix4<N2>) -> Self {
|
||||
let iso: Isometry3<N1> = crate::convert_ref_unchecked(m);
|
||||
Self::from_isometry(&iso)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: SimdRealField + RealField> From<UnitDualQuaternion<N>> for Matrix4<N>
|
||||
where
|
||||
N::Element: SimdRealField,
|
||||
{
|
||||
#[inline]
|
||||
fn from(dq: UnitDualQuaternion<N>) -> Self {
|
||||
dq.to_homogeneous()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: SimdRealField> From<UnitDualQuaternion<N>> for Isometry3<N>
|
||||
where
|
||||
N::Element: SimdRealField,
|
||||
{
|
||||
#[inline]
|
||||
fn from(dq: UnitDualQuaternion<N>) -> Self {
|
||||
dq.to_isometry()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: SimdRealField> From<Isometry3<N>> for UnitDualQuaternion<N>
|
||||
where
|
||||
N::Element: SimdRealField,
|
||||
{
|
||||
#[inline]
|
||||
fn from(iso: Isometry3<N>) -> Self {
|
||||
Self::from_isometry(&iso)
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -102,7 +102,7 @@ where
|
|||
DefaultAllocator: Allocator<N, D>,
|
||||
{
|
||||
#[inline]
|
||||
fn arbitrary<G: Gen>(rng: &mut G) -> Self {
|
||||
fn arbitrary(rng: &mut Gen) -> Self {
|
||||
Self::from_parts(Arbitrary::arbitrary(rng), Arbitrary::arbitrary(rng))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,8 @@ use crate::base::dimension::{DimMin, DimName, DimNameAdd, DimNameSum, U1};
|
|||
use crate::base::{DefaultAllocator, MatrixN, Scalar};
|
||||
|
||||
use crate::geometry::{
|
||||
AbstractRotation, Isometry, Similarity, SuperTCategoryOf, TAffine, Transform, Translation,
|
||||
AbstractRotation, Isometry, Isometry3, Similarity, SuperTCategoryOf, TAffine, Transform,
|
||||
Translation, UnitDualQuaternion, UnitQuaternion,
|
||||
};
|
||||
|
||||
/*
|
||||
|
@ -14,6 +15,7 @@ use crate::geometry::{
|
|||
* =============================================
|
||||
*
|
||||
* Isometry -> Isometry
|
||||
* Isometry3 -> UnitDualQuaternion
|
||||
* Isometry -> Similarity
|
||||
* Isometry -> Transform
|
||||
* Isometry -> Matrix (homogeneous)
|
||||
|
@ -47,6 +49,30 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<N1, N2> SubsetOf<UnitDualQuaternion<N2>> for Isometry3<N1>
|
||||
where
|
||||
N1: RealField,
|
||||
N2: RealField + SupersetOf<N1>,
|
||||
{
|
||||
#[inline]
|
||||
fn to_superset(&self) -> UnitDualQuaternion<N2> {
|
||||
let dq = UnitDualQuaternion::<N1>::from_isometry(self);
|
||||
dq.to_superset()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_in_subset(dq: &UnitDualQuaternion<N2>) -> bool {
|
||||
crate::is_convertible::<_, UnitQuaternion<N1>>(&dq.rotation())
|
||||
&& crate::is_convertible::<_, Translation<N1, _>>(&dq.translation())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_superset_unchecked(dq: &UnitDualQuaternion<N2>) -> Self {
|
||||
let dq: UnitDualQuaternion<N1> = crate::convert_ref_unchecked(dq);
|
||||
dq.to_isometry()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N1, N2, D: DimName, R1, R2> SubsetOf<Similarity<N2, D, R2>> for Isometry<N1, D, R1>
|
||||
where
|
||||
N1: RealField,
|
||||
|
|
|
@ -36,7 +36,10 @@ mod quaternion_ops;
|
|||
mod quaternion_simba;
|
||||
|
||||
mod dual_quaternion;
|
||||
#[cfg(feature = "alga")]
|
||||
mod dual_quaternion_alga;
|
||||
mod dual_quaternion_construction;
|
||||
mod dual_quaternion_conversion;
|
||||
mod dual_quaternion_ops;
|
||||
|
||||
mod unit_complex;
|
||||
|
|
|
@ -705,7 +705,7 @@ impl<N: RealField + Arbitrary> Arbitrary for Orthographic3<N>
|
|||
where
|
||||
Matrix4<N>: Send,
|
||||
{
|
||||
fn arbitrary<G: Gen>(g: &mut G) -> Self {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let left = Arbitrary::arbitrary(g);
|
||||
let right = helper::reject(g, |x: &N| *x > left);
|
||||
let bottom = Arbitrary::arbitrary(g);
|
||||
|
|
|
@ -283,7 +283,7 @@ where
|
|||
|
||||
#[cfg(feature = "arbitrary")]
|
||||
impl<N: RealField + Arbitrary> Arbitrary for Perspective3<N> {
|
||||
fn arbitrary<G: Gen>(g: &mut G) -> Self {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let znear = Arbitrary::arbitrary(g);
|
||||
let zfar = helper::reject(g, |&x: &N| !(x - znear).is_zero());
|
||||
let aspect = helper::reject(g, |&x: &N| !x.is_zero());
|
||||
|
|
|
@ -65,6 +65,24 @@ where
|
|||
{
|
||||
}
|
||||
|
||||
#[cfg(feature = "bytemuck")]
|
||||
unsafe impl<N: Scalar, D: DimName> bytemuck::Zeroable for Point<N, D>
|
||||
where
|
||||
VectorN<N, D>: bytemuck::Zeroable,
|
||||
DefaultAllocator: Allocator<N, D>,
|
||||
{
|
||||
}
|
||||
|
||||
#[cfg(feature = "bytemuck")]
|
||||
unsafe impl<N: Scalar, D: DimName> bytemuck::Pod for Point<N, D>
|
||||
where
|
||||
N: Copy,
|
||||
VectorN<N, D>: bytemuck::Pod,
|
||||
DefaultAllocator: Allocator<N, D>,
|
||||
<DefaultAllocator as Allocator<N, D>>::Buffer: Copy,
|
||||
{
|
||||
}
|
||||
|
||||
#[cfg(feature = "serde-serialize")]
|
||||
impl<N: Scalar, D: DimName> Serialize for Point<N, D>
|
||||
where
|
||||
|
@ -181,7 +199,12 @@ where
|
|||
D: DimNameAdd<U1>,
|
||||
DefaultAllocator: Allocator<N, DimNameSum<D, U1>>,
|
||||
{
|
||||
let mut res = unsafe { VectorN::<_, DimNameSum<D, U1>>::new_uninitialized() };
|
||||
let mut res = unsafe {
|
||||
crate::unimplemented_or_uninitialized_generic!(
|
||||
<DimNameSum<D, U1> as DimName>::name(),
|
||||
U1
|
||||
)
|
||||
};
|
||||
res.fixed_slice_mut::<D, U1>(0, 0).copy_from(&self.coords);
|
||||
res[(D::dim(), 0)] = N::one();
|
||||
|
||||
|
|
|
@ -24,7 +24,10 @@ where
|
|||
/// Creates a new point with uninitialized coordinates.
|
||||
#[inline]
|
||||
pub unsafe fn new_uninitialized() -> Self {
|
||||
Self::from(VectorN::new_uninitialized())
|
||||
Self::from(crate::unimplemented_or_uninitialized_generic!(
|
||||
D::name(),
|
||||
U1
|
||||
))
|
||||
}
|
||||
|
||||
/// Creates a new point with all coordinates equal to zero.
|
||||
|
@ -153,7 +156,7 @@ where
|
|||
<DefaultAllocator as Allocator<N, D>>::Buffer: Send,
|
||||
{
|
||||
#[inline]
|
||||
fn arbitrary<G: Gen>(g: &mut G) -> Self {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
Self::from(VectorN::arbitrary(g))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -40,6 +40,17 @@ impl<N: Scalar + Zero> Default for Quaternion<N> {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "bytemuck")]
|
||||
unsafe impl<N: Scalar> bytemuck::Zeroable for Quaternion<N> where Vector4<N>: bytemuck::Zeroable {}
|
||||
|
||||
#[cfg(feature = "bytemuck")]
|
||||
unsafe impl<N: Scalar> bytemuck::Pod for Quaternion<N>
|
||||
where
|
||||
Vector4<N>: bytemuck::Pod,
|
||||
N: Copy,
|
||||
{
|
||||
}
|
||||
|
||||
#[cfg(feature = "abomonation-serialize")]
|
||||
impl<N: Scalar> Abomonation for Quaternion<N>
|
||||
where
|
||||
|
@ -1542,6 +1553,17 @@ where
|
|||
pub fn inverse_transform_unit_vector(&self, v: &Unit<Vector3<N>>) -> Unit<Vector3<N>> {
|
||||
self.inverse() * v
|
||||
}
|
||||
|
||||
/// Appends to `self` a rotation given in the axis-angle form, using a linearized formulation.
|
||||
///
|
||||
/// This is faster, but approximate, way to compute `UnitQuaternion::new(axisangle) * self`.
|
||||
#[inline]
|
||||
pub fn append_axisangle_linearized(&self, axisangle: &Vector3<N>) -> Self {
|
||||
let half: N = crate::convert(0.5);
|
||||
let q1 = self.into_inner();
|
||||
let q2 = Quaternion::from_imag(axisangle * half);
|
||||
Unit::new_normalize(q1 + q2 * q1)
|
||||
}
|
||||
}
|
||||
|
||||
impl<N: RealField> Default for UnitQuaternion<N> {
|
||||
|
|
|
@ -160,7 +160,7 @@ where
|
|||
Owned<N, U4>: Send,
|
||||
{
|
||||
#[inline]
|
||||
fn arbitrary<G: Gen>(g: &mut G) -> Self {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
Self::new(
|
||||
N::arbitrary(g),
|
||||
N::arbitrary(g),
|
||||
|
@ -266,6 +266,17 @@ where
|
|||
Self::new_unchecked(q)
|
||||
}
|
||||
|
||||
/// Builds an unit quaternion from a basis assumed to be orthonormal.
|
||||
///
|
||||
/// In order to get a valid unit-quaternion, the input must be an
|
||||
/// orthonormal basis, i.e., all vectors are normalized, and the are
|
||||
/// all orthogonal to each other. These invariants are not checked
|
||||
/// by this method.
|
||||
pub fn from_basis_unchecked(basis: &[Vector3<N>; 3]) -> Self {
|
||||
let rot = Rotation3::from_basis_unchecked(basis);
|
||||
Self::from_rotation_matrix(&rot)
|
||||
}
|
||||
|
||||
/// Builds an unit quaternion from a rotation matrix.
|
||||
///
|
||||
/// # Example
|
||||
|
@ -834,7 +845,7 @@ where
|
|||
Owned<N, U3>: Send,
|
||||
{
|
||||
#[inline]
|
||||
fn arbitrary<G: Gen>(g: &mut G) -> Self {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
let axisangle = Vector3::arbitrary(g);
|
||||
Self::from_scaled_axis(axisangle)
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ use crate::base::dimension::U3;
|
|||
use crate::base::{Matrix3, Matrix4, Scalar, Vector4};
|
||||
use crate::geometry::{
|
||||
AbstractRotation, Isometry, Quaternion, Rotation, Rotation3, Similarity, SuperTCategoryOf,
|
||||
TAffine, Transform, Translation, UnitQuaternion,
|
||||
TAffine, Transform, Translation, UnitDualQuaternion, UnitQuaternion,
|
||||
};
|
||||
|
||||
/*
|
||||
|
@ -21,6 +21,7 @@ use crate::geometry::{
|
|||
* UnitQuaternion -> UnitQuaternion
|
||||
* UnitQuaternion -> Rotation<U3>
|
||||
* UnitQuaternion -> Isometry<U3>
|
||||
* UnitQuaternion -> UnitDualQuaternion
|
||||
* UnitQuaternion -> Similarity<U3>
|
||||
* UnitQuaternion -> Transform<U3>
|
||||
* UnitQuaternion -> Matrix<U4> (homogeneous)
|
||||
|
@ -121,6 +122,28 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<N1, N2> SubsetOf<UnitDualQuaternion<N2>> for UnitQuaternion<N1>
|
||||
where
|
||||
N1: RealField,
|
||||
N2: RealField + SupersetOf<N1>,
|
||||
{
|
||||
#[inline]
|
||||
fn to_superset(&self) -> UnitDualQuaternion<N2> {
|
||||
let q: UnitQuaternion<N2> = crate::convert_ref(self);
|
||||
UnitDualQuaternion::from_rotation(q)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_in_subset(dq: &UnitDualQuaternion<N2>) -> bool {
|
||||
dq.translation().vector.is_zero()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_superset_unchecked(dq: &UnitDualQuaternion<N2>) -> Self {
|
||||
crate::convert_unchecked(dq.rotation())
|
||||
}
|
||||
}
|
||||
|
||||
impl<N1, N2, R> SubsetOf<Similarity<N2, U3, R>> for UnitQuaternion<N1>
|
||||
where
|
||||
N1: RealField,
|
||||
|
|
|
@ -12,7 +12,7 @@ use crate::base::{DefaultAllocator, Matrix2, Matrix3, Matrix4, MatrixN, Scalar};
|
|||
|
||||
use crate::geometry::{
|
||||
AbstractRotation, Isometry, Rotation, Rotation2, Rotation3, Similarity, SuperTCategoryOf,
|
||||
TAffine, Transform, Translation, UnitComplex, UnitQuaternion,
|
||||
TAffine, Transform, Translation, UnitComplex, UnitDualQuaternion, UnitQuaternion,
|
||||
};
|
||||
|
||||
/*
|
||||
|
@ -21,6 +21,7 @@ use crate::geometry::{
|
|||
*
|
||||
* Rotation -> Rotation
|
||||
* Rotation3 -> UnitQuaternion
|
||||
* Rotation3 -> UnitDualQuaternion
|
||||
* Rotation2 -> UnitComplex
|
||||
* Rotation -> Isometry
|
||||
* Rotation -> Similarity
|
||||
|
@ -75,6 +76,31 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<N1, N2> SubsetOf<UnitDualQuaternion<N2>> for Rotation3<N1>
|
||||
where
|
||||
N1: RealField,
|
||||
N2: RealField + SupersetOf<N1>,
|
||||
{
|
||||
#[inline]
|
||||
fn to_superset(&self) -> UnitDualQuaternion<N2> {
|
||||
let q = UnitQuaternion::<N1>::from_rotation_matrix(self);
|
||||
let dq = UnitDualQuaternion::from_rotation(q);
|
||||
dq.to_superset()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_in_subset(dq: &UnitDualQuaternion<N2>) -> bool {
|
||||
crate::is_convertible::<_, UnitQuaternion<N1>>(&dq.rotation())
|
||||
&& dq.translation().vector.is_zero()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_superset_unchecked(dq: &UnitDualQuaternion<N2>) -> Self {
|
||||
let dq: UnitDualQuaternion<N1> = crate::convert_ref_unchecked(dq);
|
||||
dq.rotation().to_rotation_matrix()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N1, N2> SubsetOf<UnitComplex<N2>> for Rotation2<N1>
|
||||
where
|
||||
N1: RealField,
|
||||
|
|
|
@ -12,7 +12,7 @@ use std::ops::Neg;
|
|||
|
||||
use crate::base::dimension::{U1, U2, U3};
|
||||
use crate::base::storage::Storage;
|
||||
use crate::base::{Matrix2, Matrix3, MatrixN, Unit, Vector, Vector1, Vector3, VectorN};
|
||||
use crate::base::{Matrix2, Matrix3, MatrixN, Unit, Vector, Vector1, Vector2, Vector3, VectorN};
|
||||
|
||||
use crate::geometry::{Rotation2, Rotation3, UnitComplex, UnitQuaternion};
|
||||
|
||||
|
@ -53,6 +53,17 @@ impl<N: SimdRealField> Rotation2<N> {
|
|||
|
||||
/// # Construction from an existing 2D matrix or rotations
|
||||
impl<N: SimdRealField> Rotation2<N> {
|
||||
/// Builds a rotation from a basis assumed to be orthonormal.
|
||||
///
|
||||
/// In order to get a valid unit-quaternion, the input must be an
|
||||
/// orthonormal basis, i.e., all vectors are normalized, and the are
|
||||
/// all orthogonal to each other. These invariants are not checked
|
||||
/// by this method.
|
||||
pub fn from_basis_unchecked(basis: &[Vector2<N>; 2]) -> Self {
|
||||
let mat = Matrix2::from_columns(&basis[..]);
|
||||
Self::from_matrix_unchecked(mat)
|
||||
}
|
||||
|
||||
/// Builds a rotation matrix by extracting the rotation part of the given transformation `m`.
|
||||
///
|
||||
/// This is an iterative method. See `.from_matrix_eps` to provide mover
|
||||
|
@ -264,7 +275,7 @@ where
|
|||
Owned<N, U2, U2>: Send,
|
||||
{
|
||||
#[inline]
|
||||
fn arbitrary<G: Gen>(g: &mut G) -> Self {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
Self::new(N::arbitrary(g))
|
||||
}
|
||||
}
|
||||
|
@ -655,6 +666,17 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Builds a rotation from a basis assumed to be orthonormal.
|
||||
///
|
||||
/// In order to get a valid unit-quaternion, the input must be an
|
||||
/// orthonormal basis, i.e., all vectors are normalized, and the are
|
||||
/// all orthogonal to each other. These invariants are not checked
|
||||
/// by this method.
|
||||
pub fn from_basis_unchecked(basis: &[Vector3<N>; 3]) -> Self {
|
||||
let mat = Matrix3::from_columns(&basis[..]);
|
||||
Self::from_matrix_unchecked(mat)
|
||||
}
|
||||
|
||||
/// Builds a rotation matrix by extracting the rotation part of the given transformation `m`.
|
||||
///
|
||||
/// This is an iterative method. See `.from_matrix_eps` to provide mover
|
||||
|
@ -939,7 +961,7 @@ where
|
|||
Owned<N, U3>: Send,
|
||||
{
|
||||
#[inline]
|
||||
fn arbitrary<G: Gen>(g: &mut G) -> Self {
|
||||
fn arbitrary(g: &mut Gen) -> Self {
|
||||
Self::new(VectorN::arbitrary(g))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -114,7 +114,7 @@ where
|
|||
Owned<N, D>: Send,
|
||||
{
|
||||
#[inline]
|
||||
fn arbitrary<G: Gen>(rng: &mut G) -> Self {
|
||||
fn arbitrary(rng: &mut Gen) -> Self {
|
||||
let mut s: N = Arbitrary::arbitrary(rng);
|
||||
while s.is_zero() {
|
||||
s = Arbitrary::arbitrary(rng)
|
||||
|
|
|
@ -61,13 +61,13 @@ where
|
|||
}
|
||||
|
||||
#[cfg(feature = "arbitrary")]
|
||||
impl<N: Scalar + Arbitrary, D: DimName> Arbitrary for Translation<N, D>
|
||||
impl<N: Scalar + Arbitrary + Send, D: DimName> Arbitrary for Translation<N, D>
|
||||
where
|
||||
DefaultAllocator: Allocator<N, D>,
|
||||
Owned<N, D>: Send,
|
||||
{
|
||||
#[inline]
|
||||
fn arbitrary<G: Gen>(rng: &mut G) -> Self {
|
||||
fn arbitrary(rng: &mut Gen) -> Self {
|
||||
let v: VectorN<N, D> = Arbitrary::arbitrary(rng);
|
||||
Self::from(v)
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ use crate::base::{DefaultAllocator, MatrixN, Scalar, VectorN};
|
|||
|
||||
use crate::geometry::{
|
||||
AbstractRotation, Isometry, Similarity, SuperTCategoryOf, TAffine, Transform, Translation,
|
||||
Translation3, UnitDualQuaternion, UnitQuaternion,
|
||||
};
|
||||
|
||||
/*
|
||||
|
@ -17,6 +18,7 @@ use crate::geometry::{
|
|||
*
|
||||
* Translation -> Translation
|
||||
* Translation -> Isometry
|
||||
* Translation3 -> UnitDualQuaternion
|
||||
* Translation -> Similarity
|
||||
* Translation -> Transform
|
||||
* Translation -> Matrix (homogeneous)
|
||||
|
@ -69,6 +71,30 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<N1, N2> SubsetOf<UnitDualQuaternion<N2>> for Translation3<N1>
|
||||
where
|
||||
N1: RealField,
|
||||
N2: RealField + SupersetOf<N1>,
|
||||
{
|
||||
#[inline]
|
||||
fn to_superset(&self) -> UnitDualQuaternion<N2> {
|
||||
let dq = UnitDualQuaternion::<N1>::from_parts(self.clone(), UnitQuaternion::identity());
|
||||
dq.to_superset()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn is_in_subset(dq: &UnitDualQuaternion<N2>) -> bool {
|
||||
crate::is_convertible::<_, Translation<N1, _>>(&dq.translation())
|
||||
&& dq.rotation() == UnitQuaternion::identity()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn from_superset_unchecked(dq: &UnitDualQuaternion<N2>) -> Self {
|
||||
let dq: UnitDualQuaternion<N1> = crate::convert_ref_unchecked(dq);
|
||||
dq.translation()
|
||||
}
|
||||
}
|
||||
|
||||
impl<N1, N2, D: DimName, R> SubsetOf<Similarity<N2, D, R>> for Translation<N1, D>
|
||||
where
|
||||
N1: RealField,
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue