pardiso_interface.c 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. #include "pardiso_interface.h"
  2. #define MKL_INT c_int
  3. // Single Dynamic library interface
  4. #define MKL_INTERFACE_LP64 0x0
  5. #define MKL_INTERFACE_ILP64 0x1
  6. // Solver Phases
  7. #define PARDISO_SYMBOLIC (11)
  8. #define PARDISO_NUMERIC (22)
  9. #define PARDISO_SOLVE (33)
  10. #define PARDISO_CLEANUP (-1)
  11. // Prototypes for Pardiso functions
  12. void pardiso(void**, // pt
  13. const c_int*, // maxfct
  14. const c_int*, // mnum
  15. const c_int*, // mtype
  16. const c_int*, // phase
  17. const c_int*, // n
  18. const c_float*, // a
  19. const c_int*, // ia
  20. const c_int*, // ja
  21. c_int*, // perm
  22. const c_int*, //nrhs
  23. c_int*, // iparam
  24. const c_int*, //msglvl
  25. c_float*, // b
  26. c_float*, // x
  27. c_int* // error
  28. );
  29. c_int mkl_set_interface_layer(c_int);
  30. c_int mkl_get_max_threads();
  31. // Free LDL Factorization structure
  32. void free_linsys_solver_pardiso(pardiso_solver *s) {
  33. if (s) {
  34. // Free pardiso solver using internal function
  35. s->phase = PARDISO_CLEANUP;
  36. pardiso (s->pt, &(s->maxfct), &(s->mnum), &(s->mtype), &(s->phase),
  37. &(s->nKKT), &(s->fdum), s->KKT_p, s->KKT_i, &(s->idum), &(s->nrhs),
  38. s->iparm, &(s->msglvl), &(s->fdum), &(s->fdum), &(s->error));
  39. if ( s->error != 0 ){
  40. #ifdef PRINTING
  41. c_eprint("Error during MKL Pardiso cleanup: %d", (int)s->error);
  42. #endif
  43. }
  44. // Check each attribute of the structure and free it if it exists
  45. if (s->KKT) csc_spfree(s->KKT);
  46. if (s->KKT_i) c_free(s->KKT_i);
  47. if (s->KKT_p) c_free(s->KKT_p);
  48. if (s->bp) c_free(s->bp);
  49. if (s->sol) c_free(s->sol);
  50. if (s->rho_inv_vec) c_free(s->rho_inv_vec);
  51. // These are required for matrix updates
  52. if (s->Pdiag_idx) c_free(s->Pdiag_idx);
  53. if (s->PtoKKT) c_free(s->PtoKKT);
  54. if (s->AtoKKT) c_free(s->AtoKKT);
  55. if (s->rhotoKKT) c_free(s->rhotoKKT);
  56. c_free(s);
  57. }
  58. }
  59. // Initialize factorization structure
  60. c_int init_linsys_solver_pardiso(pardiso_solver ** sp, const csc * P, const csc * A, c_float sigma, const c_float * rho_vec, c_int polish){
  61. c_int i; // loop counter
  62. c_int nnzKKT; // Number of nonzeros in KKT
  63. // Define Variables
  64. c_int n_plus_m; // n_plus_m dimension
  65. // Allocate private structure to store KKT factorization
  66. pardiso_solver *s;
  67. s = c_calloc(1, sizeof(pardiso_solver));
  68. *sp = s;
  69. // Size of KKT
  70. s->n = P->n;
  71. s->m = A->m;
  72. n_plus_m = s->n + s->m;
  73. s->nKKT = n_plus_m;
  74. // Sigma parameter
  75. s->sigma = sigma;
  76. // Polishing flag
  77. s->polish = polish;
  78. // Link Functions
  79. s->solve = &solve_linsys_pardiso;
  80. s->free = &free_linsys_solver_pardiso;
  81. s->update_matrices = &update_linsys_solver_matrices_pardiso;
  82. s->update_rho_vec = &update_linsys_solver_rho_vec_pardiso;
  83. // Assign type
  84. s->type = MKL_PARDISO_SOLVER;
  85. // Working vector
  86. s->bp = (c_float *)c_malloc(sizeof(c_float) * n_plus_m);
  87. // Solution vector
  88. s->sol = (c_float *)c_malloc(sizeof(c_float) * n_plus_m);
  89. // Parameter vector
  90. s->rho_inv_vec = (c_float *)c_malloc(sizeof(c_float) * n_plus_m);
  91. // Form KKT matrix
  92. if (polish){ // Called from polish()
  93. // Use s->rho_inv_vec for storing param2 = vec(delta)
  94. for (i = 0; i < A->m; i++){
  95. s->rho_inv_vec[i] = sigma;
  96. }
  97. s->KKT = form_KKT(P, A, 1, sigma, s->rho_inv_vec, OSQP_NULL, OSQP_NULL, OSQP_NULL, OSQP_NULL, OSQP_NULL);
  98. }
  99. else { // Called from ADMM algorithm
  100. // Allocate vectors of indices
  101. s->PtoKKT = c_malloc((P->p[P->n]) * sizeof(c_int));
  102. s->AtoKKT = c_malloc((A->p[A->n]) * sizeof(c_int));
  103. s->rhotoKKT = c_malloc((A->m) * sizeof(c_int));
  104. // Use s->rho_inv_vec for storing param2 = rho_inv_vec
  105. for (i = 0; i < A->m; i++){
  106. s->rho_inv_vec[i] = 1. / rho_vec[i];
  107. }
  108. s->KKT = form_KKT(P, A, 1, sigma, s->rho_inv_vec,
  109. s->PtoKKT, s->AtoKKT,
  110. &(s->Pdiag_idx), &(s->Pdiag_n), s->rhotoKKT);
  111. }
  112. // Check if matrix has been created
  113. if (!(s->KKT)) {
  114. #ifdef PRINTING
  115. c_eprint("Error in forming KKT matrix");
  116. #endif
  117. free_linsys_solver_pardiso(s);
  118. return OSQP_LINSYS_SOLVER_INIT_ERROR;
  119. } else {
  120. // Adjust indexing for Pardiso
  121. nnzKKT = s->KKT->p[s->KKT->m];
  122. s->KKT_i = c_malloc((nnzKKT) * sizeof(c_int));
  123. s->KKT_p = c_malloc((s->KKT->m + 1) * sizeof(c_int));
  124. for(i = 0; i < nnzKKT; i++){
  125. s->KKT_i[i] = s->KKT->i[i] + 1;
  126. }
  127. for(i = 0; i < n_plus_m+1; i++){
  128. s->KKT_p[i] = s->KKT->p[i] + 1;
  129. }
  130. }
  131. // Set MKL interface layer (Long integers if activated)
  132. #ifdef DLONG
  133. mkl_set_interface_layer(MKL_INTERFACE_ILP64);
  134. #else
  135. mkl_set_interface_layer(MKL_INTERFACE_LP64);
  136. #endif
  137. // Set Pardiso variables
  138. s->mtype = -2; // Real symmetric indefinite matrix
  139. s->nrhs = 1; // Number of right hand sides
  140. s->maxfct = 1; // Maximum number of numerical factorizations
  141. s->mnum = 1; // Which factorization to use
  142. s->msglvl = 0; // Do not print statistical information
  143. s->error = 0; // Initialize error flag
  144. for ( i = 0; i < 64; i++ ) {
  145. s->iparm[i] = 0; // Setup Pardiso control parameters
  146. s->pt[i] = 0; // Initialize the internal solver memory pointer
  147. }
  148. s->iparm[0] = 1; // No solver default
  149. s->iparm[1] = 3; // Fill-in reordering from OpenMP
  150. if (polish) {
  151. s->iparm[5] = 1; // Write solution into b
  152. } else {
  153. s->iparm[5] = 0; // Do NOT write solution into b
  154. }
  155. /* s->iparm[7] = 2; // Max number of iterative refinement steps */
  156. s->iparm[7] = 0; // Number of iterative refinement steps (auto, performs them only if perturbed pivots are obtained)
  157. s->iparm[9] = 13; // Perturb the pivot elements with 1E-13
  158. s->iparm[34] = 0; // Use Fortran-style indexing for indices
  159. /* s->iparm[34] = 1; // Use C-style indexing for indices */
  160. // Print number of threads
  161. s->nthreads = mkl_get_max_threads();
  162. // Reordering and symbolic factorization
  163. s->phase = PARDISO_SYMBOLIC;
  164. pardiso (s->pt, &(s->maxfct), &(s->mnum), &(s->mtype), &(s->phase),
  165. &(s->nKKT), s->KKT->x, s->KKT_p, s->KKT_i, &(s->idum), &(s->nrhs),
  166. s->iparm, &(s->msglvl), &(s->fdum), &(s->fdum), &(s->error));
  167. if ( s->error != 0 ){
  168. #ifdef PRINTING
  169. c_eprint("Error during symbolic factorization: %d", (int)s->error);
  170. #endif
  171. free_linsys_solver_pardiso(s);
  172. *sp = OSQP_NULL;
  173. return OSQP_LINSYS_SOLVER_INIT_ERROR;
  174. }
  175. // Numerical factorization
  176. s->phase = PARDISO_NUMERIC;
  177. pardiso (s->pt, &(s->maxfct), &(s->mnum), &(s->mtype), &(s->phase),
  178. &(s->nKKT), s->KKT->x, s->KKT_p, s->KKT_i, &(s->idum), &(s->nrhs),
  179. s->iparm, &(s->msglvl), &(s->fdum), &(s->fdum), &(s->error));
  180. if ( s->error ){
  181. #ifdef PRINTING
  182. c_eprint("Error during numerical factorization: %d", (int)s->error);
  183. #endif
  184. free_linsys_solver_pardiso(s);
  185. *sp = OSQP_NULL;
  186. return OSQP_LINSYS_SOLVER_INIT_ERROR;
  187. }
  188. // No error
  189. return 0;
  190. }
  191. // Returns solution to linear system Ax = b with solution stored in b
  192. c_int solve_linsys_pardiso(pardiso_solver * s, c_float * b) {
  193. c_int j;
  194. // Back substitution and iterative refinement
  195. s->phase = PARDISO_SOLVE;
  196. pardiso (s->pt, &(s->maxfct), &(s->mnum), &(s->mtype), &(s->phase),
  197. &(s->nKKT), s->KKT->x, s->KKT_p, s->KKT_i, &(s->idum), &(s->nrhs),
  198. s->iparm, &(s->msglvl), b, s->sol, &(s->error));
  199. if ( s->error != 0 ){
  200. #ifdef PRINTING
  201. c_eprint("Error during linear system solution: %d", (int)s->error);
  202. #endif
  203. return 1;
  204. }
  205. if (!(s->polish)) {
  206. /* copy x_tilde from s->sol */
  207. for (j = 0 ; j < s->n ; j++) {
  208. b[j] = s->sol[j];
  209. }
  210. /* compute z_tilde from b and s->sol */
  211. for (j = 0 ; j < s->m ; j++) {
  212. b[j + s->n] += s->rho_inv_vec[j] * s->sol[j + s->n];
  213. }
  214. }
  215. return 0;
  216. }
  217. // Update solver structure with new P and A
  218. c_int update_linsys_solver_matrices_pardiso(pardiso_solver * s, const csc *P, const csc *A) {
  219. // Update KKT matrix with new P
  220. update_KKT_P(s->KKT, P, s->PtoKKT, s->sigma, s->Pdiag_idx, s->Pdiag_n);
  221. // Update KKT matrix with new A
  222. update_KKT_A(s->KKT, A, s->AtoKKT);
  223. // Perform numerical factorization
  224. s->phase = PARDISO_NUMERIC;
  225. pardiso (s->pt, &(s->maxfct), &(s->mnum), &(s->mtype), &(s->phase),
  226. &(s->nKKT), s->KKT->x, s->KKT_p, s->KKT_i, &(s->idum), &(s->nrhs),
  227. s->iparm, &(s->msglvl), &(s->fdum), &(s->fdum), &(s->error));
  228. // Return exit flag
  229. return s->error;
  230. }
  231. c_int update_linsys_solver_rho_vec_pardiso(pardiso_solver * s, const c_float * rho_vec) {
  232. c_int i;
  233. // Update internal rho_inv_vec
  234. for (i = 0; i < s->m; i++){
  235. s->rho_inv_vec[i] = 1. / rho_vec[i];
  236. }
  237. // Update KKT matrix with new rho_vec
  238. update_KKT_param2(s->KKT, s->rho_inv_vec, s->rhotoKKT, s->m);
  239. // Perform numerical factorization
  240. s->phase = PARDISO_NUMERIC;
  241. pardiso (s->pt, &(s->maxfct), &(s->mnum), &(s->mtype), &(s->phase),
  242. &(s->nKKT), s->KKT->x, s->KKT_p, s->KKT_i, &(s->idum), &(s->nrhs),
  243. s->iparm, &(s->msglvl), &(s->fdum), &(s->fdum), &(s->error));
  244. // Return exit flag
  245. return s->error;
  246. }