/* test_solve_linsys.h -- unit tests for forming, factorizing and solving the KKT linear system. */

  1. #include <stdio.h>
  2. #include "osqp.h"
  3. #include "cs.h"
  4. #include "util.h"
  5. #include "minunit.h"
  6. #include "lin_sys.h"
  7. #include "solve_linsys/data.h"
  8. static const char* test_solveKKT() {
  9. c_int m, exitflag = 0;
  10. c_float *rho_vec;
  11. LinSysSolver *s; // Private structure to form KKT factorization
  12. OSQPSettings *settings = (OSQPSettings *)c_malloc(sizeof(OSQPSettings)); // Settings
  13. solve_linsys_sols_data *data = generate_problem_solve_linsys_sols_data();
  14. // Settings
  15. settings->rho = data->test_solve_KKT_rho;
  16. settings->sigma = data->test_solve_KKT_sigma;
  17. // Set rho_vec
  18. m = data->test_solve_KKT_A->m;
  19. rho_vec = (c_float*) c_calloc(m, sizeof(c_float));
  20. vec_add_scalar(rho_vec, settings->rho, m);
  21. // Form and factorize KKT matrix
  22. exitflag = init_linsys_solver(&s, data->test_solve_KKT_Pu, data->test_solve_KKT_A,
  23. settings->sigma, rho_vec, LINSYS_SOLVER, 0);
  24. // Solve KKT x = b via LDL given factorization
  25. s->solve(s, data->test_solve_KKT_rhs);
  26. mu_assert(
  27. "Linear systems solve tests: error in forming and solving KKT system!",
  28. vec_norm_inf_diff(data->test_solve_KKT_rhs, data->test_solve_KKT_x,
  29. data->test_solve_KKT_m + data->test_solve_KKT_n) < TESTS_TOL);
  30. // Cleanup
  31. s->free(s);
  32. c_free(settings);
  33. c_free(rho_vec);
  34. clean_problem_solve_linsys_sols_data(data);
  35. return 0;
  36. }
  37. #ifdef ENABLE_MKL_PARDISO
  38. static char* test_solveKKT_pardiso() {
  39. c_int m, exitflag = 0;
  40. c_float *rho_vec;
  41. LinSysSolver *s; // Private structure to form KKT factorization
  42. OSQPSettings *settings = (OSQPSettings *)c_malloc(sizeof(OSQPSettings)); // Settings
  43. solve_linsys_sols_data *data = generate_problem_solve_linsys_sols_data();
  44. // Settings
  45. settings->rho = data->test_solve_KKT_rho;
  46. settings->sigma = data->test_solve_KKT_sigma;
  47. // Set rho_vec
  48. m = data->test_solve_KKT_A->m;
  49. rho_vec = c_calloc(m, sizeof(c_float));
  50. vec_add_scalar(rho_vec, settings->rho, m);
  51. // Load Pardiso shared library
  52. exitflag = load_linsys_solver(MKL_PARDISO_SOLVER);
  53. mu_assert("Linear system solve test: error in loading Pardiso shared library",
  54. exitflag == 0);
  55. // Form and factorize KKT matrix
  56. exitflag = init_linsys_solver(&s, data->test_solve_KKT_Pu, data->test_solve_KKT_A,
  57. settings->sigma, rho_vec, MKL_PARDISO_SOLVER, 0);
  58. // Solve KKT x = b via LDL given factorization
  59. s->solve(s, data->test_solve_KKT_rhs);
  60. mu_assert(
  61. "Linear systems solve tests: error in forming and solving KKT system with PARDISO!",
  62. vec_norm_inf_diff(data->test_solve_KKT_rhs, data->test_solve_KKT_x,
  63. data->test_solve_KKT_m + data->test_solve_KKT_n) < TESTS_TOL);
  64. // Cleanup
  65. s->free(s);
  66. c_free(settings);
  67. c_free(rho_vec);
  68. clean_problem_solve_linsys_sols_data(data);
  69. // Unload Pardiso shared library
  70. exitflag = unload_linsys_solver(MKL_PARDISO_SOLVER);
  71. return 0;
  72. }
  73. #endif
  74. static const char* test_solve_linsys()
  75. {
  76. mu_run_test(test_solveKKT);
  77. #ifdef ENABLE_MKL_PARDISO
  78. mu_run_test(test_solveKKT_pardiso);
  79. #endif
  80. return 0;
  81. }