00001 00030 #ifndef _MSC_VER 00031 # include <itpp/config.h> 00032 #else 00033 # include <itpp/config_msvc.h> 00034 #endif 00035 00036 #if defined(HAVE_LAPACK) 00037 # include <itpp/base/algebra/lapack.h> 00038 #endif 00039 00040 #include <itpp/base/algebra/ls_solve.h> 00041 00042 00043 namespace itpp 00044 { 00045 00046 // ----------- ls_solve_chol ----------------------------------------------------------- 00047 00048 #if defined(HAVE_LAPACK) 00049 00050 bool ls_solve_chol(const mat &A, const vec &b, vec &x) 00051 { 00052 int n, lda, ldb, nrhs, info; 00053 n = lda = ldb = A.rows(); 00054 nrhs = 1; 00055 char uplo = 'U'; 00056 00057 it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square"); 00058 it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!"); 00059 00060 ivec ipiv(n); 00061 x = b; 00062 mat Chol = A; 00063 00064 dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info); 00065 00066 return (info == 0); 00067 } 00068 00069 00070 bool ls_solve_chol(const mat &A, const mat &B, mat &X) 00071 { 00072 int n, lda, ldb, nrhs, info; 00073 n = lda = ldb = A.rows(); 00074 nrhs = B.cols(); 00075 char uplo = 'U'; 00076 00077 it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square"); 00078 it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!"); 00079 00080 ivec ipiv(n); 00081 X = B; 00082 mat Chol = A; 00083 00084 dposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info); 00085 00086 return (info == 0); 00087 } 00088 00089 bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x) 00090 { 00091 int n, lda, ldb, nrhs, info; 00092 n = lda = ldb = A.rows(); 00093 nrhs = 1; 00094 char uplo = 'U'; 00095 00096 it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square"); 00097 it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!"); 00098 00099 ivec ipiv(n); 00100 x = b; 00101 cmat Chol = A; 00102 00103 zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, x._data(), &ldb, &info); 00104 00105 return (info == 0); 00106 } 00107 00108 bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X) 00109 { 00110 int n, lda, ldb, nrhs, info; 00111 n = lda = ldb = A.rows(); 00112 nrhs = B.cols(); 00113 char uplo = 'U'; 00114 00115 it_assert_debug(A.cols() == n, "ls_solve_chol: System-matrix is not square"); 00116 it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!"); 00117 00118 ivec ipiv(n); 00119 X = B; 00120 cmat Chol = A; 00121 00122 zposv_(&uplo, &n, &nrhs, Chol._data(), &lda, X._data(), &ldb, &info); 00123 00124 return (info == 0); 00125 } 00126 00127 #else 00128 00129 bool ls_solve_chol(const mat &A, const vec &b, vec &x) 00130 { 00131 it_error("LAPACK library is needed to use ls_solve_chol() function"); 00132 return false; 00133 } 00134 00135 bool ls_solve_chol(const mat &A, const mat &B, mat &X) 00136 { 00137 it_error("LAPACK library is needed to use ls_solve_chol() function"); 00138 return false; 00139 } 00140 00141 bool ls_solve_chol(const cmat &A, const cvec &b, cvec &x) 00142 { 00143 it_error("LAPACK library is needed to use ls_solve_chol() function"); 00144 return false; 00145 } 00146 00147 bool ls_solve_chol(const cmat &A, const cmat &B, cmat &X) 00148 { 00149 it_error("LAPACK library is needed to use ls_solve_chol() function"); 00150 return false; 00151 } 00152 00153 #endif // HAVE_LAPACK 00154 00155 vec ls_solve_chol(const mat &A, const vec &b) 00156 { 00157 vec x; 00158 bool info; 00159 info = ls_solve_chol(A, b, x); 00160 it_assert_debug(info, "ls_solve_chol: Failed solving the system"); 00161 return x; 00162 } 00163 00164 mat ls_solve_chol(const mat &A, const mat &B) 00165 { 00166 mat X; 00167 bool info; 00168 info = ls_solve_chol(A, B, X); 00169 it_assert_debug(info, "ls_solve_chol: Failed solving the system"); 00170 return X; 00171 } 00172 00173 cvec ls_solve_chol(const cmat &A, const cvec &b) 00174 { 00175 cvec x; 00176 bool info; 00177 info = ls_solve_chol(A, b, x); 00178 it_assert_debug(info, "ls_solve_chol: Failed solving the system"); 00179 return x; 00180 } 00181 00182 cmat ls_solve_chol(const cmat &A, const cmat &B) 00183 { 00184 cmat X; 00185 bool info; 00186 info = ls_solve_chol(A, B, X); 00187 it_assert_debug(info, "ls_solve_chol: Failed solving the system"); 00188 return X; 00189 } 00190 00191 00192 // --------- ls_solve --------------------------------------------------------------- 00193 #if defined(HAVE_LAPACK) 00194 00195 bool ls_solve(const mat &A, const vec &b, vec &x) 00196 { 00197 int n, lda, ldb, nrhs, info; 00198 n = lda = ldb = A.rows(); 00199 nrhs = 1; 00200 00201 it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square"); 00202 it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!"); 00203 00204 ivec ipiv(n); 00205 x = b; 00206 mat LU = A; 00207 00208 dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info); 00209 00210 return (info == 0); 00211 } 00212 00213 bool ls_solve(const mat &A, const mat &B, mat &X) 00214 { 00215 int n, lda, ldb, nrhs, info; 00216 n = lda = ldb = A.rows(); 00217 nrhs = B.cols(); 00218 00219 it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square"); 00220 it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!"); 00221 00222 ivec ipiv(n); 00223 X = B; 00224 mat LU = A; 00225 00226 dgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info); 00227 00228 return (info == 0); 00229 } 00230 00231 bool ls_solve(const cmat &A, const cvec &b, cvec &x) 00232 { 00233 int n, lda, ldb, nrhs, info; 00234 n = lda = ldb = A.rows(); 00235 nrhs = 1; 00236 00237 it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square"); 00238 it_assert_debug(n == b.size(), "The number of rows in A must equal the length of b!"); 00239 00240 ivec ipiv(n); 00241 x = b; 00242 cmat LU = A; 00243 00244 zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), x._data(), &ldb, &info); 00245 00246 return (info == 0); 00247 } 00248 00249 bool ls_solve(const cmat &A, const cmat &B, cmat &X) 00250 { 00251 int n, lda, ldb, nrhs, info; 00252 n = lda = ldb = A.rows(); 00253 nrhs = B.cols(); 00254 00255 it_assert_debug(A.cols() == n, "ls_solve: System-matrix is not square"); 00256 it_assert_debug(n == B.rows(), "The number of rows in A must equal the length of B!"); 00257 00258 ivec ipiv(n); 00259 X = B; 00260 cmat LU = A; 00261 00262 zgesv_(&n, &nrhs, LU._data(), &lda, ipiv._data(), X._data(), &ldb, &info); 00263 00264 return (info == 0); 00265 } 00266 00267 #else 00268 00269 bool ls_solve(const mat &A, const vec &b, vec &x) 00270 { 00271 it_error("LAPACK library is needed to use ls_solve() function"); 00272 return false; 00273 } 00274 00275 bool ls_solve(const mat &A, const mat &B, mat &X) 00276 { 00277 it_error("LAPACK library is needed to use ls_solve() function"); 00278 return false; 00279 } 00280 00281 bool ls_solve(const cmat &A, const cvec &b, cvec &x) 00282 { 00283 it_error("LAPACK library is needed to use ls_solve() function"); 00284 return false; 00285 } 00286 00287 bool ls_solve(const cmat &A, const cmat &B, cmat &X) 00288 { 00289 it_error("LAPACK library is needed to use ls_solve() function"); 00290 return false; 00291 } 00292 00293 #endif // HAVE_LAPACK 00294 00295 vec ls_solve(const mat &A, const vec &b) 00296 { 00297 vec x; 00298 bool info; 00299 info = ls_solve(A, b, x); 00300 it_assert_debug(info, "ls_solve: Failed solving the system"); 00301 return x; 00302 } 00303 00304 mat ls_solve(const mat &A, const mat &B) 00305 { 00306 mat X; 00307 bool info; 00308 info = ls_solve(A, B, X); 00309 it_assert_debug(info, "ls_solve: Failed solving the system"); 00310 return X; 00311 } 00312 00313 cvec ls_solve(const cmat &A, const cvec &b) 00314 { 00315 cvec x; 00316 bool info; 00317 info = ls_solve(A, b, x); 00318 it_assert_debug(info, "ls_solve: Failed solving the system"); 00319 return x; 00320 } 00321 00322 cmat ls_solve(const cmat &A, const cmat &B) 00323 { 00324 cmat X; 00325 bool info; 00326 info = ls_solve(A, B, X); 00327 it_assert_debug(info, "ls_solve: Failed solving the system"); 00328 return X; 00329 } 00330 00331 00332 // ----------------- ls_solve_od ------------------------------------------------------------------ 00333 #if defined(HAVE_LAPACK) 00334 00335 bool ls_solve_od(const mat &A, const vec &b, vec &x) 00336 { 00337 int m, n, lda, ldb, nrhs, lwork, info; 00338 char trans = 'N'; 00339 m = lda = ldb = A.rows(); 00340 n = A.cols(); 00341 nrhs = 1; 00342 lwork = n + std::max(m, nrhs); 00343 00344 it_assert_debug(m >= n, "The system is under-determined!"); 00345 it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!"); 00346 00347 vec work(lwork); 00348 x = b; 00349 mat QR = A; 00350 00351 dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info); 00352 x.set_size(n, true); 00353 00354 return (info == 0); 00355 } 00356 00357 bool ls_solve_od(const mat &A, const mat &B, mat &X) 00358 { 00359 int m, n, lda, ldb, nrhs, lwork, info; 00360 char trans = 'N'; 00361 m = lda = ldb = A.rows(); 00362 n = A.cols(); 00363 nrhs = B.cols(); 00364 lwork = n + std::max(m, nrhs); 00365 00366 it_assert_debug(m >= n, "The system is under-determined!"); 00367 it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!"); 00368 00369 vec work(lwork); 00370 X = B; 00371 mat QR = A; 00372 00373 dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info); 00374 X.set_size(n, nrhs, true); 00375 00376 return (info == 0); 00377 } 00378 00379 bool ls_solve_od(const cmat &A, const cvec &b, cvec &x) 00380 { 00381 int m, n, lda, ldb, nrhs, lwork, info; 00382 char trans = 'N'; 00383 m = lda = ldb = A.rows(); 00384 n = A.cols(); 00385 nrhs = 1; 00386 lwork = n + std::max(m, nrhs); 00387 00388 it_assert_debug(m >= n, "The system is under-determined!"); 00389 it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!"); 00390 00391 cvec work(lwork); 00392 x = b; 00393 cmat QR = A; 00394 00395 zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info); 00396 x.set_size(n, true); 00397 00398 return (info == 0); 00399 } 00400 00401 bool ls_solve_od(const cmat &A, const cmat &B, cmat &X) 00402 { 00403 int m, n, lda, ldb, nrhs, lwork, info; 00404 char trans = 'N'; 00405 m = lda = ldb = A.rows(); 00406 n = A.cols(); 00407 nrhs = B.cols(); 00408 lwork = n + std::max(m, nrhs); 00409 00410 it_assert_debug(m >= n, "The system is under-determined!"); 00411 it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!"); 00412 00413 cvec work(lwork); 00414 X = B; 00415 cmat QR = A; 00416 00417 zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info); 00418 X.set_size(n, nrhs, true); 00419 00420 return (info == 0); 00421 } 00422 00423 #else 00424 00425 bool ls_solve_od(const mat &A, const vec &b, vec &x) 00426 { 00427 it_error("LAPACK library is needed to use ls_solve_od() function"); 00428 return false; 00429 } 00430 00431 bool ls_solve_od(const mat &A, const mat &B, mat &X) 00432 { 00433 it_error("LAPACK library is needed to use ls_solve_od() function"); 00434 return false; 00435 } 00436 00437 bool ls_solve_od(const cmat &A, const cvec &b, cvec &x) 00438 { 00439 it_error("LAPACK library is needed to use ls_solve_od() function"); 00440 return false; 00441 } 00442 00443 bool ls_solve_od(const cmat &A, const cmat &B, cmat &X) 00444 { 00445 it_error("LAPACK library is needed to use ls_solve_od() function"); 00446 return false; 00447 } 00448 00449 #endif // HAVE_LAPACK 00450 00451 vec ls_solve_od(const mat &A, const vec &b) 00452 { 00453 vec x; 00454 bool info; 00455 info = ls_solve_od(A, b, x); 00456 it_assert_debug(info, "ls_solve_od: Failed solving the system"); 00457 return x; 00458 } 00459 00460 mat ls_solve_od(const mat &A, const mat &B) 00461 { 00462 mat X; 00463 bool info; 00464 info = ls_solve_od(A, B, X); 00465 it_assert_debug(info, "ls_solve_od: Failed solving the system"); 00466 return X; 00467 } 00468 00469 cvec ls_solve_od(const cmat &A, const cvec &b) 00470 { 00471 cvec x; 00472 bool info; 00473 info = ls_solve_od(A, b, x); 00474 it_assert_debug(info, "ls_solve_od: Failed solving the system"); 00475 return x; 00476 } 00477 00478 cmat ls_solve_od(const cmat &A, const cmat &B) 00479 { 00480 cmat X; 00481 bool info; 00482 info = ls_solve_od(A, B, X); 00483 it_assert_debug(info, "ls_solve_od: Failed solving the system"); 00484 return X; 00485 } 00486 00487 // ------------------- ls_solve_ud ----------------------------------------------------------- 00488 #if defined(HAVE_LAPACK) 00489 00490 bool ls_solve_ud(const mat &A, const vec &b, vec &x) 00491 { 00492 int m, n, lda, ldb, nrhs, lwork, info; 00493 char trans = 'N'; 00494 m = lda = A.rows(); 00495 n = A.cols(); 00496 ldb = n; 00497 nrhs = 1; 00498 lwork = m + std::max(n, nrhs); 00499 00500 it_assert_debug(m < n, "The system is over-determined!"); 00501 it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!"); 00502 00503 vec work(lwork); 00504 x = b; 00505 x.set_size(n, true); 00506 mat QR = A; 00507 00508 dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info); 00509 00510 return (info == 0); 00511 } 00512 00513 bool ls_solve_ud(const mat &A, const mat &B, mat &X) 00514 { 00515 int m, n, lda, ldb, nrhs, lwork, info; 00516 char trans = 'N'; 00517 m = lda = A.rows(); 00518 n = A.cols(); 00519 ldb = n; 00520 nrhs = B.cols(); 00521 lwork = m + std::max(n, nrhs); 00522 00523 it_assert_debug(m < n, "The system is over-determined!"); 00524 it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!"); 00525 00526 vec work(lwork); 00527 X = B; 00528 X.set_size(n, std::max(m, nrhs), true); 00529 mat QR = A; 00530 00531 dgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info); 00532 X.set_size(n, nrhs, true); 00533 00534 return (info == 0); 00535 } 00536 00537 bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x) 00538 { 00539 int m, n, lda, ldb, nrhs, lwork, info; 00540 char trans = 'N'; 00541 m = lda = A.rows(); 00542 n = A.cols(); 00543 ldb = n; 00544 nrhs = 1; 00545 lwork = m + std::max(n, nrhs); 00546 00547 it_assert_debug(m < n, "The system is over-determined!"); 00548 it_assert_debug(m == b.size(), "The number of rows in A must equal the length of b!"); 00549 00550 cvec work(lwork); 00551 x = b; 00552 x.set_size(n, true); 00553 cmat QR = A; 00554 00555 zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, x._data(), &ldb, work._data(), &lwork, &info); 00556 00557 return (info == 0); 00558 } 00559 00560 bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X) 00561 { 00562 int m, n, lda, ldb, nrhs, lwork, info; 00563 char trans = 'N'; 00564 m = lda = A.rows(); 00565 n = A.cols(); 00566 ldb = n; 00567 nrhs = B.cols(); 00568 lwork = m + std::max(n, nrhs); 00569 00570 it_assert_debug(m < n, "The system is over-determined!"); 00571 it_assert_debug(m == B.rows(), "The number of rows in A must equal the length of b!"); 00572 00573 cvec work(lwork); 00574 X = B; 00575 X.set_size(n, std::max(m, nrhs), true); 00576 cmat QR = A; 00577 00578 zgels_(&trans, &m, &n, &nrhs, QR._data(), &lda, X._data(), &ldb, work._data(), &lwork, &info); 00579 X.set_size(n, nrhs, true); 00580 00581 return (info == 0); 00582 } 00583 00584 #else 00585 00586 bool ls_solve_ud(const mat &A, const vec &b, vec &x) 00587 { 00588 it_error("LAPACK library is needed to use ls_solve_ud() function"); 00589 return false; 00590 } 00591 00592 bool ls_solve_ud(const mat &A, const mat &B, mat &X) 00593 { 00594 it_error("LAPACK library is needed to use ls_solve_ud() function"); 00595 return false; 00596 } 00597 00598 bool ls_solve_ud(const cmat &A, const cvec &b, cvec &x) 00599 { 00600 it_error("LAPACK library is needed to use ls_solve_ud() function"); 00601 return false; 00602 } 00603 00604 bool ls_solve_ud(const cmat &A, const cmat &B, cmat &X) 00605 { 00606 it_error("LAPACK library is needed to use ls_solve_ud() function"); 00607 return false; 00608 } 00609 00610 #endif // HAVE_LAPACK 00611 00612 00613 vec ls_solve_ud(const mat &A, const vec &b) 00614 { 00615 vec x; 00616 bool info; 00617 info = ls_solve_ud(A, b, x); 00618 it_assert_debug(info, "ls_solve_ud: Failed solving the system"); 00619 return x; 00620 } 00621 00622 mat ls_solve_ud(const mat &A, const mat &B) 00623 { 00624 mat X; 00625 bool info; 00626 info = ls_solve_ud(A, B, X); 00627 it_assert_debug(info, "ls_solve_ud: Failed solving the system"); 00628 return X; 00629 } 00630 00631 cvec ls_solve_ud(const cmat &A, const cvec &b) 00632 { 00633 cvec x; 00634 bool info; 00635 info = ls_solve_ud(A, b, x); 00636 it_assert_debug(info, "ls_solve_ud: Failed solving the system"); 00637 return x; 00638 } 00639 00640 cmat ls_solve_ud(const cmat &A, const cmat &B) 00641 { 00642 cmat X; 00643 bool info; 00644 info = ls_solve_ud(A, B, X); 00645 it_assert_debug(info, "ls_solve_ud: Failed solving the system"); 00646 return X; 00647 } 00648 00649 00650 // ---------------------- backslash ----------------------------------------- 00651 00652 bool backslash(const mat &A, const vec &b, vec &x) 00653 { 00654 int m = A.rows(), n = A.cols(); 00655 bool info; 00656 00657 if (m == n) 00658 info = ls_solve(A, b, x); 00659 else if (m > n) 00660 info = ls_solve_od(A, b, x); 00661 else 00662 info = ls_solve_ud(A, b, x); 00663 00664 return info; 00665 } 00666 00667 00668 vec backslash(const mat &A, const vec &b) 00669 { 00670 vec x; 00671 bool info; 00672 info = backslash(A, b, x); 00673 it_assert_debug(info, "backslash(): solution was not found"); 00674 return x; 00675 } 00676 00677 00678 bool backslash(const mat &A, const mat &B, mat &X) 00679 { 00680 int m = A.rows(), n = A.cols(); 00681 bool info; 00682 00683 if (m == n) 00684 info = ls_solve(A, B, X); 00685 else if (m > n) 00686 info = ls_solve_od(A, B, X); 00687 else 00688 info = ls_solve_ud(A, B, X); 00689 00690 return info; 00691 } 00692 00693 00694 mat backslash(const mat &A, const mat &B) 00695 { 00696 mat X; 00697 bool info; 00698 info = backslash(A, B, X); 00699 it_assert_debug(info, "backslash(): solution was not found"); 00700 return X; 00701 } 00702 00703 00704 bool backslash(const cmat &A, const cvec &b, cvec &x) 00705 { 00706 int m = A.rows(), n = A.cols(); 00707 bool info; 00708 00709 if (m == n) 00710 info = ls_solve(A, b, x); 00711 else if (m > n) 00712 info = ls_solve_od(A, b, x); 00713 else 00714 info = ls_solve_ud(A, b, x); 00715 00716 return info; 00717 } 00718 00719 00720 cvec backslash(const cmat &A, const cvec &b) 00721 { 00722 cvec x; 00723 bool info; 00724 info = backslash(A, b, x); 00725 it_assert_debug(info, "backslash(): solution was not found"); 00726 return x; 00727 } 00728 00729 00730 bool backslash(const cmat &A, const cmat &B, cmat &X) 00731 { 00732 int m = A.rows(), n = A.cols(); 00733 bool info; 00734 00735 if (m == n) 00736 info = ls_solve(A, B, X); 00737 else if (m > n) 00738 info = ls_solve_od(A, B, X); 00739 else 00740 info = ls_solve_ud(A, B, X); 00741 00742 return info; 00743 } 00744 00745 cmat backslash(const cmat &A, const cmat &B) 00746 { 00747 cmat X; 00748 bool info; 00749 info = backslash(A, B, X); 00750 it_assert_debug(info, "backslash(): solution was not found"); 00751 return X; 00752 } 00753 00754 00755 // -------------------------------------------------------------------------- 00756 00757 vec forward_substitution(const mat &L, const vec &b) 00758 { 00759 int n = L.rows(); 00760 vec x(n); 00761 00762 forward_substitution(L, b, x); 00763 00764 return x; 00765 } 00766 00767 void forward_substitution(const mat &L, const vec &b, vec &x) 00768 { 00769 it_assert(L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size(), 00770 "forward_substitution: dimension mismatch"); 00771 int n = L.rows(), i, j; 00772 double temp; 00773 00774 x(0) = b(0) / L(0, 0); 00775 for (i = 1;i < n;i++) { 00776 // Should be: x(i)=((b(i)-L(i,i,0,i-1)*x(0,i-1))/L(i,i))(0); but this is to slow. 00777 //i_pos=i*L._row_offset(); 00778 temp = 0; 00779 for (j = 0; j < i; j++) { 00780 temp += L._elem(i, j) * x(j); 00781 //temp+=L._data()[i_pos+j]*x(j); 00782 } 00783 x(i) = (b(i) - temp) / L._elem(i, i); 00784 //x(i)=(b(i)-temp)/L._data()[i_pos+i]; 00785 } 00786 } 00787 00788 vec forward_substitution(const mat &L, int p, const vec &b) 00789 { 00790 int n = L.rows(); 00791 vec x(n); 00792 00793 forward_substitution(L, p, b, x); 00794 00795 return x; 00796 } 00797 00798 void forward_substitution(const mat &L, int p, const vec &b, vec &x) 00799 { 00800 it_assert(L.rows() == L.cols() && L.cols() == b.size() && b.size() == x.size() && p <= L.rows() / 2, 00801 "forward_substitution: dimension mismatch"); 00802 int n = L.rows(), i, j; 00803 00804 x = b; 00805 00806 for (j = 0;j < n;j++) { 00807 x(j) /= L(j, j); 00808 for (i = j + 1;i < std::min(j + p + 1, n);i++) { 00809 x(i) -= L(i, j) * x(j); 00810 } 00811 } 00812 } 00813 00814 vec backward_substitution(const mat &U, const vec &b) 00815 { 00816 vec x(U.rows()); 00817 backward_substitution(U, b, x); 00818 00819 return x; 00820 } 00821 00822 void backward_substitution(const mat &U, const vec &b, vec &x) 00823 { 00824 it_assert(U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size(), 00825 "backward_substitution: dimension mismatch"); 00826 int n = U.rows(), i, j; 00827 double temp; 00828 00829 x(n - 1) = b(n - 1) / U(n - 1, n - 1); 00830 for (i = n - 2; i >= 0; i--) { 00831 // Should be: x(i)=((b(i)-U(i,i,i+1,n-1)*x(i+1,n-1))/U(i,i))(0); but this is too slow. 00832 temp = 0; 00833 //i_pos=i*U._row_offset(); 00834 for (j = i + 1; j < n; j++) { 00835 temp += U._elem(i, j) * x(j); 00836 //temp+=U._data()[i_pos+j]*x(j); 00837 } 00838 x(i) = (b(i) - temp) / U._elem(i, i); 00839 //x(i)=(b(i)-temp)/U._data()[i_pos+i]; 00840 } 00841 } 00842 00843 vec backward_substitution(const mat &U, int q, const vec &b) 00844 { 00845 vec x(U.rows()); 00846 backward_substitution(U, q, b, x); 00847 00848 return x; 00849 } 00850 00851 void backward_substitution(const mat &U, int q, const vec &b, vec &x) 00852 { 00853 it_assert(U.rows() == U.cols() && U.cols() == b.size() && b.size() == x.size() && q <= U.rows() / 2, 00854 "backward_substitution: dimension mismatch"); 00855 int n = U.rows(), i, j; 00856 00857 x = b; 00858 00859 for (j = n - 1; j >= 0; j--) { 00860 x(j) /= U(j, j); 00861 for (i = std::max(0, j - q); i < j; i++) { 00862 x(i) -= U(i, j) * x(j); 00863 } 00864 } 00865 } 00866 00867 } // namespace itpp
Generated on Wed Dec 7 2011 03:38:19 for IT++ by Doxygen 1.7.4