My Project
Loading...
Searching...
No Matches
TridiagonalMatrix.hpp
Go to the documentation of this file.
1// -*- mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4 -*-
2// vi: set et ts=4 sw=4 sts=4:
3/*
4 This file is part of the Open Porous Media project (OPM).
5
6 OPM is free software: you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation, either version 2 of the License, or
9 (at your option) any later version.
10
11 OPM is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with OPM. If not, see <http://www.gnu.org/licenses/>.
18
19 Consult the COPYING file in the top-level source directory of this
20 module for the precise wording of the license and the list of
21 copyright holders.
22*/
27#ifndef OPM_TRIDIAGONAL_MATRIX_HH
28#define OPM_TRIDIAGONAL_MATRIX_HH
29
30#include <algorithm>
31#include <cassert>
32#include <cmath>
33#include <iosfwd>
34#include <vector>
35
36namespace Opm {
37
/*!
 * \brief Provides a tridiagonal matrix that also supports non-zero entries
 *        in the upper right and lower left corners (for n > 2).
 *
 * Storage layout (n = size()), as established by at():
 *   diag_[1][i]   = A(i, i)
 *   diag_[0][j]   = A(j+1, j)   for j < n-1  (sub-diagonal, indexed by column)
 *   diag_[2][j]   = A(j-1, j)   for j > 0    (super-diagonal, indexed by column)
 *   diag_[2][0]   = A(0, n-1)                (upper-right corner, only if n > 2)
 *   diag_[0][n-1] = A(n-1, 0)                (lower-left corner, only if n > 2)
 * For n <= 2 the slots diag_[2][0] and diag_[0][n-1] are not matrix entries.
 */
template <class Scalar>
class TridiagonalMatrix
{
    /*!
     * \brief Proxy object for one row of the matrix.
     *
     * It doubles as the (const_)iterator type: incrementing moves to the next
     * row and dereferencing yields the row object itself.
     */
    struct TridiagRow_ {
        TridiagRow_(TridiagonalMatrix& m, size_t rowIdx)
            : matrix_(m)
            , rowIdx_(rowIdx)
        {}

        //! Writable reference to the entry in column colIdx of this row.
        Scalar& operator[](size_t colIdx)
        { return matrix_.at(rowIdx_, colIdx); }

        //! Value of the entry in column colIdx of this row.
        Scalar operator[](size_t colIdx) const
        { return matrix_.at(rowIdx_, colIdx); }

        //! Advance to the next row.
        TridiagRow_& operator++()
        { ++ rowIdx_; return *this; }

        //! Go back to the previous row.
        TridiagRow_& operator--()
        { -- rowIdx_; return *this; }

        //! Two row objects are equal if they refer to the same row of the same matrix.
        bool operator==(const TridiagRow_& other) const
        { return other.rowIdx_ == rowIdx_ && &other.matrix_ == &matrix_; }

        //! Negation of operator==.
        bool operator!=(const TridiagRow_& other) const
        { return !operator==(other); }

        //! Dereferencing the "iterator" yields the row itself.
        TridiagRow_& operator*()
        { return *this; }

        //! Index of the row this object refers to.
        size_t index() const
        { return rowIdx_; }

    private:
        TridiagonalMatrix& matrix_;
        mutable size_t rowIdx_;
    };

public:
    typedef Scalar FieldType;
    typedef TridiagRow_ RowType;
    typedef size_t SizeType;
    typedef TridiagRow_ iterator;
    typedef TridiagRow_ const_iterator;

    //! Create a numRows x numRows matrix (entries value-initialized).
    explicit TridiagonalMatrix(size_t numRows = 0)
    {
        resize(numRows);
    }

    //! Create a numRows x numRows matrix with every storage slot set to value.
    TridiagonalMatrix(size_t numRows, Scalar value)
    {
        resize(numRows);
        this->operator=(value);
    }

    //! Copy constructor.
    TridiagonalMatrix(const TridiagonalMatrix& source)
    { *this = source; }

    //! Return the number of rows/columns of the matrix.
    size_t size() const
    { return diag_[0].size(); }

    //! Return the number of rows of the matrix.
    size_t rows() const
    { return size(); }

    //! Return the number of columns of the matrix.
    size_t cols() const
    { return size(); }

    //! Change the number of rows (and columns) of the matrix.
    void resize(size_t n)
    {
        if (n == size())
            return;

        for (unsigned diagIdx = 0; diagIdx < 3; ++ diagIdx)
            diag_[diagIdx].resize(n);
    }

    /*!
     * \brief Writable access to entry (rowIdx, colIdx).
     *
     * Only positions on the three bands and, for n > 2, the two corners
     * (0, n-1) and (n-1, 0) exist; other positions trigger an assertion.
     */
    Scalar& at(size_t rowIdx, size_t colIdx)
    {
        size_t diagIdx, slotIdx;
        slot_(rowIdx, colIdx, diagIdx, slotIdx);
        return diag_[diagIdx][slotIdx];
    }

    //! Read access to entry (rowIdx, colIdx); same addressing as the non-const overload.
    Scalar at(size_t rowIdx, size_t colIdx) const
    {
        size_t diagIdx, slotIdx;
        slot_(rowIdx, colIdx, diagIdx, slotIdx);
        return diag_[diagIdx][slotIdx];
    }

    //! Assignment operator from another tridiagonal matrix.
    TridiagonalMatrix& operator=(const TridiagonalMatrix& source)
    {
        for (unsigned diagIdx = 0; diagIdx < 3; ++ diagIdx)
            diag_[diagIdx] = source.diag_[diagIdx];

        return *this;
    }

    //! Set every storage slot to value (the matrix size is kept).
    TridiagonalMatrix& operator=(Scalar value)
    {
        const size_t n = size();
        for (unsigned diagIdx = 0; diagIdx < 3; ++ diagIdx)
            diag_[diagIdx].assign(n, value);

        return *this;
    }

    //! Iterator for the first row.
    iterator begin()
    { return TridiagRow_(*this, 0); }

    //! Const iterator for the first row.
    const_iterator begin() const
    { return TridiagRow_(const_cast<TridiagonalMatrix&>(*this), 0); }

    //! Iterator one past the last row.
    iterator end()
    { return TridiagRow_(*this, size()); }

    //! Const iterator one past the last row.
    const_iterator end() const
    { return TridiagRow_(const_cast<TridiagonalMatrix&>(*this), size()); }

    //! Row access operator.
    TridiagRow_ operator[](size_t rowIdx)
    { return TridiagRow_(*this, rowIdx); }

    /*!
     * \brief Read-only row access operator.
     *
     * The row proxy needs a non-const matrix reference; the returned proxy is
     * const, so only its read-only operator[] is usable.
     */
    const TridiagRow_ operator[](size_t rowIdx) const
    { return TridiagRow_(const_cast<TridiagonalMatrix&>(*this), rowIdx); }

    //! Multiplication with a Scalar.
    TridiagonalMatrix& operator*=(Scalar alpha)
    {
        const size_t n = size();
        for (unsigned diagIdx = 0; diagIdx < 3; ++ diagIdx)
            for (size_t i = 0; i < n; ++i)
                diag_[diagIdx][i] *= alpha;

        return *this;
    }

    //! Division by a Scalar (implemented as multiplication by the reciprocal).
    TridiagonalMatrix& operator/=(Scalar alpha)
    {
        alpha = 1.0/alpha;
        return operator*=(alpha);
    }

    //! Subtraction operator.
    TridiagonalMatrix& operator-=(const TridiagonalMatrix& other)
    { return axpy(-1.0, other); }

    //! Addition operator.
    TridiagonalMatrix& operator+=(const TridiagonalMatrix& other)
    { return axpy(1.0, other); }

    /*!
     * \brief Multiply and add the matrix entries of another tridiagonal matrix.
     *
     * Computes this += alpha * other, slot by slot. Both matrices must have
     * the same size.
     */
    TridiagonalMatrix& axpy(Scalar alpha, const TridiagonalMatrix& other)
    {
        assert(size() == other.size());

        const size_t n = size();
        for (unsigned diagIdx = 0; diagIdx < 3; ++ diagIdx)
            for (size_t i = 0; i < n; ++ i)
                diag_[diagIdx][i] += alpha * other.diag_[diagIdx][i];

        return *this;
    }

    //! Matrix-vector product: dest = A * source.
    template<class Vector>
    void mv(const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        const size_t n = size();
        for (size_t i = 0; i < n; ++i)
            dest[i] = rowDot_(i, source);
    }

    //! Additive matrix-vector product: dest += A * source.
    template<class Vector>
    void umv(const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        const size_t n = size();
        for (size_t i = 0; i < n; ++i)
            dest[i] += rowDot_(i, source);
    }

    //! Subtractive matrix-vector product: dest -= A * source.
    template<class Vector>
    void mmv(const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        const size_t n = size();
        for (size_t i = 0; i < n; ++i)
            dest[i] -= rowDot_(i, source);
    }

    //! Scaled additive matrix-vector product: dest += alpha * A * source.
    template<class Vector>
    void usmv(Scalar alpha, const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        const size_t n = size();
        for (size_t i = 0; i < n; ++i)
            dest[i] += alpha*rowDot_(i, source);
    }

    //! Transposed matrix-vector product: dest = A^T * source.
    template<class Vector>
    void mtv(const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        const size_t n = size();
        for (size_t i = 0; i < n; ++i)
            dest[i] = columnDot_(i, source);
    }

    //! Transposed additive matrix-vector product: dest += A^T * source.
    template<class Vector>
    void umtv(const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        const size_t n = size();
        for (size_t i = 0; i < n; ++i)
            dest[i] += columnDot_(i, source);
    }

    //! Transposed subtractive matrix-vector product: dest -= A^T * source.
    template<class Vector>
    void mmtv (const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        const size_t n = size();
        for (size_t i = 0; i < n; ++i)
            dest[i] -= columnDot_(i, source);
    }

    //! Transposed scaled additive matrix-vector product: dest += alpha * A^T * source.
    template<class Vector>
    void usmtv(Scalar alpha, const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        const size_t n = size();
        for (size_t i = 0; i < n; ++i)
            dest[i] += alpha*columnDot_(i, source);
    }

    //! Calculate the Frobenius norm.
    Scalar frobeniusNorm() const
    { return std::sqrt(frobeniusNormSquared()); }

    //! Calculate the squared Frobenius norm (sum of squared matrix entries).
    Scalar frobeniusNormSquared() const
    {
        Scalar result = 0;

        const size_t n = size();
        for (unsigned diagIdx = 0; diagIdx < 3; ++ diagIdx) {
            for (size_t i = 0; i < n; ++ i) {
                // skip slots that do not represent a matrix entry (n <= 2)
                if (!slotInUse_(diagIdx, i))
                    continue;
                const Scalar v = diag_[diagIdx][i];
                result += v*v;
            }
        }

        return result;
    }

    //! Calculate the infinity norm (maximum absolute row sum).
    Scalar infinityNorm() const
    {
        Scalar result = 0;

        const size_t n = size();
        for (size_t i = 0; i < n; ++i) {
            Scalar rowSum = std::abs(diag_[1][i]);
            if (i > 0)
                rowSum += std::abs(diag_[0][i - 1]);      // A(i, i-1)
            if (i + 1 < n)
                rowSum += std::abs(diag_[2][i + 1]);      // A(i, i+1)
            if (n > 2) {
                if (i == 0)
                    rowSum += std::abs(diag_[2][0]);      // A(0, n-1)
                else if (i == n - 1)
                    rowSum += std::abs(diag_[0][n - 1]);  // A(n-1, 0)
            }
            result = std::max(result, rowSum);
        }

        return result;
    }

    /*!
     * \brief Calculate the solution x of A x = b.
     *
     * Uses Gaussian elimination without pivoting, so the leading pivots must
     * be non-zero (e.g. a diagonally dominant matrix). For n > 2 both corner
     * entries are taken into account. b must provide begin()/end() with
     * exactly size() elements; x must provide operator[].
     */
    template <class XVector, class BVector>
    void solve(XVector& x, const BVector& b) const
    {
        const size_t n = size();
        if (n == 0)
            return;

        std::vector<Scalar> bStar(n);
        std::copy(b.begin(), b.end(), bStar.begin());

        if (n > 2)
            solveCyclic_(x, bStar);
        else
            solveTridiagonal_(x, bStar);
    }

    //! Print the matrix to a given output stream.
    void print(std::ostream& os) const;

private:
    // Map matrix position (rowIdx, colIdx) to its storage slot
    // diag_[diagIdx][slotIdx]; asserts that the position exists.
    void slot_(size_t rowIdx, size_t colIdx, size_t& diagIdx, size_t& slotIdx) const
    {
        const size_t n = size();
        assert(rowIdx < n && colIdx < n);

        // corner entries exist only for n > 2; for smaller matrices (0, n-1)
        // and (n-1, 0) are ordinary band positions
        if (n > 2) {
            if (rowIdx == 0 && colIdx == n - 1) {
                diagIdx = 2;
                slotIdx = 0;
                return;
            }
            if (rowIdx == n - 1 && colIdx == 0) {
                diagIdx = 0;
                slotIdx = n - 1;
                return;
            }
        }

        // off-band positions give diagIdx >= 3 (including unsigned wrap-around
        // when rowIdx > colIdx + 1)
        diagIdx = 1 + colIdx - rowIdx;
        assert(diagIdx < 3);
        slotIdx = colIdx;
    }

    // True if storage slot diag_[diagIdx][slotIdx] represents a matrix entry.
    bool slotInUse_(unsigned diagIdx, size_t slotIdx) const
    {
        const size_t n = size();
        if (n > 2)
            return true;
        if (diagIdx == 2 && slotIdx == 0)
            return false;
        if (diagIdx == 0 && slotIdx == n - 1)
            return false;
        return true;
    }

    // Row i of A times v, i.e. sum_j A(i, j) * v[j].
    template <class Vector>
    Scalar rowDot_(size_t i, const Vector& v) const
    {
        const size_t n = size();
        Scalar sum = diag_[1][i]*v[i];
        if (i > 0)
            sum += diag_[0][i - 1]*v[i - 1];         // A(i, i-1)
        if (i + 1 < n)
            sum += diag_[2][i + 1]*v[i + 1];         // A(i, i+1)
        if (n > 2) {
            if (i == 0)
                sum += diag_[2][0]*v[n - 1];         // A(0, n-1)
            else if (i == n - 1)
                sum += diag_[0][n - 1]*v[0];         // A(n-1, 0)
        }
        return sum;
    }

    // Column i of A times v, i.e. (A^T v)_i = sum_j A(j, i) * v[j].
    template <class Vector>
    Scalar columnDot_(size_t i, const Vector& v) const
    {
        const size_t n = size();
        Scalar sum = diag_[1][i]*v[i];
        if (i > 0)
            sum += diag_[2][i]*v[i - 1];             // A(i-1, i)
        if (i + 1 < n)
            sum += diag_[0][i]*v[i + 1];             // A(i+1, i)
        if (n > 2) {
            if (i == 0)
                sum += diag_[0][n - 1]*v[n - 1];     // A(n-1, 0)
            else if (i == n - 1)
                sum += diag_[2][0]*v[0];             // A(0, n-1)
        }
        return sum;
    }

    // Thomas algorithm for a plain tridiagonal system (used for n <= 2, where
    // no corner entries exist). bStar holds the right-hand side and is
    // overwritten.
    template <class XVector>
    void solveTridiagonal_(XVector& x, std::vector<Scalar>& bStar) const
    {
        const size_t n = size();
        std::vector<Scalar> mainDiag(diag_[1]);

        // forward elimination of A(i, i-1) = diag_[0][i-1] using row i-1,
        // whose super-diagonal entry A(i-1, i) is diag_[2][i]
        for (size_t i = 1; i < n; ++i) {
            const Scalar alpha = diag_[0][i - 1]/mainDiag[i - 1];
            mainDiag[i] -= alpha*diag_[2][i];
            bStar[i] -= alpha*bStar[i - 1];
        }

        // back substitution
        x[n - 1] = bStar[n - 1]/mainDiag[n - 1];
        for (size_t i = n - 1; i-- > 0; )
            x[i] = (bStar[i] - diag_[2][i + 1]*x[i + 1])/mainDiag[i];
    }

    // Gaussian elimination for n > 2 including both corner entries. The
    // upper-right entry creates fill-in down the last column, the lower-left
    // entry creates fill-in along the last row; both are tracked explicitly.
    template <class XVector>
    void solveCyclic_(XVector& x, std::vector<Scalar>& bStar) const
    {
        const size_t n = size();
        assert(n > 2);

        // working copies in natural indexing:
        //   mainDiag[i] = A(i, i)
        //   upper[i]    = A(i, i+1)   for i <= n-2
        //   lower[i]    = A(i+1, i)   for i <= n-2
        //   lastCol[i]  = A(i, n-1)   for i <= n-3
        //   lastRow[j]  = A(n-1, j)   for j <= n-3
        std::vector<Scalar> mainDiag(diag_[1]);
        std::vector<Scalar> upper(diag_[2].begin() + 1, diag_[2].end());
        std::vector<Scalar> lower(diag_[0].begin(), diag_[0].end() - 1);
        std::vector<Scalar> lastCol(n - 2, Scalar(0));
        std::vector<Scalar> lastRow(n - 2, Scalar(0));
        lastCol[0] = diag_[2][0];
        lastRow[0] = diag_[0][n - 1];

        // eliminate the sub-diagonal of rows 1 .. n-2; row i inherits the
        // last-column entry of row i-1
        for (size_t i = 1; i + 1 < n; ++i) {
            const Scalar alpha = lower[i - 1]/mainDiag[i - 1];
            mainDiag[i] -= alpha*upper[i - 1];
            if (i + 2 < n)
                lastCol[i] -= alpha*lastCol[i - 1];
            else
                upper[i] -= alpha*lastCol[i - 1]; // row n-2: column n-1 is its super-diagonal
            bStar[i] -= alpha*bStar[i - 1];
        }

        // eliminate the last row from left to right; each step shifts the
        // fill-in one column to the right and couples into the diagonal
        for (size_t j = 0; j + 2 < n; ++j) {
            const Scalar alpha = lastRow[j]/mainDiag[j];
            if (j + 3 < n)
                lastRow[j + 1] -= alpha*upper[j];
            else
                lower[n - 2] -= alpha*upper[j]; // column n-2 is the sub-diagonal of row n-1
            mainDiag[n - 1] -= alpha*lastCol[j];
            bStar[n - 1] -= alpha*bStar[j];
        }
        {
            const Scalar alpha = lower[n - 2]/mainDiag[n - 2];
            mainDiag[n - 1] -= alpha*upper[n - 2];
            bStar[n - 1] -= alpha*bStar[n - 2];
        }

        // back substitution
        x[n - 1] = bStar[n - 1]/mainDiag[n - 1];
        x[n - 2] = (bStar[n - 2] - upper[n - 2]*x[n - 1])/mainDiag[n - 2];
        for (size_t i = n - 2; i-- > 0; )
            x[i] = (bStar[i] - upper[i]*x[i + 1] - lastCol[i]*x[n - 1])/mainDiag[i];
    }

    mutable std::vector<Scalar> diag_[3];
};
809
810} // namespace Opm
811
812template <class Scalar>
813std::ostream& operator<<(std::ostream& os, const Opm::TridiagonalMatrix<Scalar>& mat)
814{
815 mat.print(os);
816 return os;
817}
818
819#endif
Provides a tridiagonal matrix that also supports non-zero entries in the upper right and lower left.
Definition TridiagonalMatrix.hpp:50
const_iterator end() const
Const iterator pointing one past the last row.
Definition TridiagonalMatrix.hpp:237
TridiagonalMatrix(const TridiagonalMatrix &source)
Copy constructor.
Definition TridiagonalMatrix.hpp:127
TridiagRow_ operator[](size_t rowIdx)
Row access operator.
Definition TridiagonalMatrix.hpp:243
size_t cols() const
Return the number of columns of the matrix.
Definition TridiagonalMatrix.hpp:145
size_t rows() const
Return the number of rows of the matrix.
Definition TridiagonalMatrix.hpp:139
Scalar frobeniusNormSquared() const
Calculate the squared frobenius norm.
Definition TridiagonalMatrix.hpp:661
TridiagonalMatrix & operator/=(Scalar alpha)
Division by a Scalar.
Definition TridiagonalMatrix.hpp:270
TridiagonalMatrix & operator=(Scalar value)
Assignment operator from a Scalar.
Definition TridiagonalMatrix.hpp:214
void umtv(const Vector &source, Vector &dest) const
Transposed additive matrix-vector product.
Definition TridiagonalMatrix.hpp:537
const_iterator begin() const
Const iterator for the first row.
Definition TridiagonalMatrix.hpp:231
iterator begin()
Iterator for the first row.
Definition TridiagonalMatrix.hpp:225
void usmtv(Scalar alpha, const Vector &source, Vector &dest) const
Transposed scaled additive matrix-vector product.
Definition TridiagonalMatrix.hpp:617
TridiagonalMatrix & operator+=(const TridiagonalMatrix &other)
Addition operator.
Definition TridiagonalMatrix.hpp:292
void mv(const Vector &source, Vector &dest) const
Matrix-vector product.
Definition TridiagonalMatrix.hpp:334
Scalar infinityNorm() const
Calculate the infinity norm.
Definition TridiagonalMatrix.hpp:678
Scalar at(size_t rowIdx, size_t colIdx) const
Access an entry.
Definition TridiagonalMatrix.hpp:184
const TridiagRow_ operator[](size_t rowIdx) const
Row access operator.
Definition TridiagonalMatrix.hpp:249
void mmv(const Vector &source, Vector &dest) const
Subtractive matrix-vector product.
Definition TridiagonalMatrix.hpp:414
void usmv(Scalar alpha, const Vector &source, Vector &dest) const
Scaled additive matrix-vector product.
Definition TridiagonalMatrix.hpp:454
void mmtv(const Vector &source, Vector &dest) const
Transposed subtractive matrix-vector product.
Definition TridiagonalMatrix.hpp:577
void print(std::ostream &os) const
Print the matrix to a given output stream.
Definition TridiagonalMatrix.cpp:32
TridiagonalMatrix & operator-=(const TridiagonalMatrix &other)
Subtraction operator.
Definition TridiagonalMatrix.hpp:286
size_t size() const
Return the number of rows/columns of the matrix.
Definition TridiagonalMatrix.hpp:133
TridiagonalMatrix & axpy(Scalar alpha, const TridiagonalMatrix &other)
Multiply and add the matrix entries of another tridiagonal matrix.
Definition TridiagonalMatrix.hpp:309
TridiagonalMatrix & operator=(const TridiagonalMatrix &source)
Assignment operator from another tridiagonal matrix.
Definition TridiagonalMatrix.hpp:203
void solve(XVector &x, const BVector &b) const
Calculate the solution for a linear system of equations.
Definition TridiagonalMatrix.hpp:713
void resize(size_t n)
Change the number of rows of the matrix.
Definition TridiagonalMatrix.hpp:151
void mtv(const Vector &source, Vector &dest) const
Transposed matrix-vector product.
Definition TridiagonalMatrix.hpp:497
Scalar frobeniusNorm() const
Calculate the frobenius norm.
Definition TridiagonalMatrix.hpp:653
void umv(const Vector &source, Vector &dest) const
Additive matrix-vector product.
Definition TridiagonalMatrix.hpp:374
Scalar & at(size_t rowIdx, size_t colIdx)
Access an entry.
Definition TridiagonalMatrix.hpp:163
TridiagonalMatrix & operator*=(Scalar alpha)
Multiplication with a Scalar.
Definition TridiagonalMatrix.hpp:255
This class implements a small container which holds the transmissibility multipliers for all the face...
Definition Exceptions.hpp:30