BandedSystem.h 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
#include <algorithm>
#include <cstddef>
#include <utility>

#include <Eigen/Eigen>
// The BandedSystem class efficiently solves banded linear systems Ax=b.
// A is an N*N band matrix with lower bandwidth lowerBw
// and upper bandwidth upperBw.
// Banded LU factorization takes O(N * lowerBw * upperBw) time,
// i.e. it is linear in N for fixed bandwidths.
  7. // 带状系统类用于解决
  8. // 有效的带状线性系统 Ax=b。
  9. // A 是一个 N*N 带矩阵,具有较低带宽 lowerBw
  10. // 和上带宽 upperBw。
  11. // 带状 LU 分解的时间复杂度为 O(N)。
  12. #ifndef BANDEDSYSTEM_H
  13. #define BANDEDSYSTEM_H
  14. class BandedSystem // 这里没太看懂
  15. {
  16. public:
  17. // The size of A, as well as the lower/upper
  18. // banded width p/q are needed
  19. inline void create(const int &n, const int &p, const int &q)
  20. {
  21. // In case of re-creating before destroying
  22. destroy();
  23. N = n;
  24. lowerBw = p;
  25. upperBw = q;
  26. int actualSize = N * (lowerBw + upperBw + 1);
  27. ptrData = new double[actualSize];
  28. std::fill_n(ptrData, actualSize, 0.0);
  29. return;
  30. }
  31. inline void destroy()
  32. {
  33. if (ptrData != nullptr)
  34. {
  35. delete[] ptrData;
  36. ptrData = nullptr;
  37. }
  38. return;
  39. }
  40. private:
  41. int N;
  42. int lowerBw;
  43. int upperBw;
  44. // Compulsory nullptr initialization here
  45. double *ptrData = nullptr;
  46. public:
  47. // Reset the matrix to zero
  48. inline void reset(void) // 这个reset有问题,只有主对角线和设置的upbound lowbound对应的位置是0,其他都是不确定的
  49. {
  50. std::fill_n(ptrData, N * (lowerBw + upperBw + 1), 0.0);
  51. return;
  52. }
  53. // The band matrix is stored as suggested in "Matrix Computation"
  54. inline const double &operator()(const int &i, const int &j) const
  55. {
  56. return ptrData[(i - j + upperBw) * N + j];
  57. }
  58. inline double &operator()(const int &i, const int &j)
  59. {
  60. return ptrData[(i - j + upperBw) * N + j];
  61. }
  62. // This function conducts banded LU factorization in place
  63. // Note that NO PIVOT is applied on the matrix "A" for efficiency!!!
  64. inline void factorizeLU()
  65. {
  66. int iM, jM;
  67. double cVl;
  68. for (int k = 0; k <= N - 2; ++k)
  69. {
  70. iM = std::min(k + lowerBw, N - 1);
  71. cVl = operator()(k, k);
  72. for (int i = k + 1; i <= iM; ++i)
  73. {
  74. if (operator()(i, k) != 0.0)
  75. {
  76. operator()(i, k) /= cVl;
  77. }
  78. }
  79. jM = std::min(k + upperBw, N - 1);
  80. for (int j = k + 1; j <= jM; ++j)
  81. {
  82. cVl = operator()(k, j);
  83. if (cVl != 0.0)
  84. {
  85. for (int i = k + 1; i <= iM; ++i)
  86. {
  87. if (operator()(i, k) != 0.0)
  88. {
  89. operator()(i, j) -= operator()(i, k) * cVl;
  90. }
  91. }
  92. }
  93. }
  94. }
  95. return;
  96. }
  97. // This function solves Ax=b, then stores x in b
  98. // The input b is required to be N*m, i.e.,
  99. // m vectors to be solved.
  100. template <typename EIGENMAT>
  101. inline void solve(EIGENMAT &b) const
  102. {
  103. int iM;
  104. for (int j = 0; j <= N - 1; ++j)
  105. {
  106. iM = std::min(j + lowerBw, N - 1);
  107. for (int i = j + 1; i <= iM; ++i)
  108. {
  109. if (operator()(i, j) != 0.0)
  110. {
  111. b.row(i) -= operator()(i, j) * b.row(j);
  112. }
  113. }
  114. }
  115. for (int j = N - 1; j >= 0; --j)
  116. {
  117. b.row(j) /= operator()(j, j);
  118. iM = std::max(0, j - upperBw);
  119. for (int i = iM; i <= j - 1; ++i)
  120. {
  121. if (operator()(i, j) != 0.0)
  122. {
  123. b.row(i) -= operator()(i, j) * b.row(j);
  124. }
  125. }
  126. }
  127. return;
  128. }
  129. // This function solves ATx=b, then stores x in b
  130. // The input b is required to be N*m, i.e.,
  131. // m vectors to be solved.
  132. template <typename EIGENMAT>
  133. inline void solveAdj(EIGENMAT &b) const
  134. {
  135. int iM;
  136. for (int j = 0; j <= N - 1; ++j)
  137. {
  138. b.row(j) /= operator()(j, j);
  139. iM = std::min(j + upperBw, N - 1);
  140. for (int i = j + 1; i <= iM; ++i)
  141. {
  142. if (operator()(j, i) != 0.0)
  143. {
  144. b.row(i) -= operator()(j, i) * b.row(j);
  145. }
  146. }
  147. }
  148. for (int j = N - 1; j >= 0; --j)
  149. {
  150. iM = std::max(0, j - lowerBw);
  151. for (int i = iM; i <= j - 1; ++i)
  152. {
  153. if (operator()(j, i) != 0.0)
  154. {
  155. b.row(i) -= operator()(j, i) * b.row(j);
  156. }
  157. }
  158. }
  159. }
  160. };
  161. #endif // BANDEDSYSTEM_H