arc_spline.hpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. #pragma once
  2. #include <Eigen/Core>
  3. #include <algorithm>
  4. #include <cassert>
  5. #include <cstdio>
  6. #include <vector>
  7. namespace arc_spline
  8. {
  9. /*
  10. ================================================================================
  11. ================================================================================
  12. band_matrix
  13. ================================================================================
  14. ================================================================================
  15. */
  16. // band matrix solver 带矩阵求解器
  17. class band_matrix
  18. {
  19. private:
  20. std::vector<std::vector<double>> m_upper; // upper band
  21. std::vector<std::vector<double>> m_lower; // lower band
  22. public:
  23. band_matrix(){}; // constructor
  24. band_matrix(int dim, int n_u, int n_l); // constructor
  25. ~band_matrix(){}; // destructor
  26. void resize(int dim, int n_u, int n_l); // init with dim,n_u,n_l
  27. int dim() const; // matrix dimension
  28. int num_upper() const
  29. {
  30. return m_upper.size() - 1;
  31. }
  32. int num_lower() const
  33. {
  34. return m_lower.size() - 1;
  35. }
  36. // access operator
  37. double &operator()(int i, int j); // write
  38. double operator()(int i, int j) const; // read
  39. // we can store an additional diogonal (in m_lower)
  40. double &saved_diag(int i);
  41. double saved_diag(int i) const;
  42. void lu_decompose();
  43. std::vector<double> r_solve(const std::vector<double> &b) const;
  44. std::vector<double> l_solve(const std::vector<double> &b) const;
  45. std::vector<double> lu_solve(const std::vector<double> &b,
  46. bool is_lu_decomposed = false);
  47. };
  48. band_matrix::band_matrix(int dim, int n_u, int n_l)
  49. {
  50. resize(dim, n_u, n_l);
  51. }
  52. void band_matrix::resize(int dim, int n_u, int n_l)
  53. {
  54. assert(dim > 0);
  55. assert(n_u >= 0);
  56. assert(n_l >= 0);
  57. m_upper.resize(n_u + 1);
  58. m_lower.resize(n_l + 1);
  59. for (size_t i = 0; i < m_upper.size(); i++)
  60. {
  61. m_upper[i].resize(dim);
  62. }
  63. for (size_t i = 0; i < m_lower.size(); i++)
  64. {
  65. m_lower[i].resize(dim);
  66. }
  67. }
  68. int band_matrix::dim() const
  69. {
  70. if (m_upper.size() > 0)
  71. {
  72. return m_upper[0].size();
  73. }
  74. else
  75. {
  76. return 0;
  77. }
  78. }
  79. // defines the new operator (), so that we can access the elements
  80. // by A(i,j), index going from i=0,...,dim()-1
  81. // 定义新的运算符 (),以便我们可以访问元素
  82. // 通过 A(i,j),索引从 i=0,...,dim()-1 开始
  83. double &band_matrix::operator()(int i, int j)
  84. {
  85. int k = j - i; // what band is the entry
  86. assert((i >= 0) && (i < dim()) && (j >= 0) && (j < dim()));
  87. assert((-num_lower() <= k) && (k <= num_upper()));
  88. // k=0 -> diogonal, k<0 lower left part, k>0 upper right part
  89. if (k >= 0)
  90. return m_upper[k][i];
  91. else
  92. return m_lower[-k][i];
  93. }
  94. double band_matrix::operator()(int i, int j) const
  95. {
  96. int k = j - i; // what band is the entry
  97. assert((i >= 0) && (i < dim()) && (j >= 0) && (j < dim()));
  98. assert((-num_lower() <= k) && (k <= num_upper()));
  99. // k=0 -> diogonal, k<0 lower left part, k>0 upper right part
  100. if (k >= 0)
  101. return m_upper[k][i];
  102. else
  103. return m_lower[-k][i];
  104. }
  105. // second diag (used in LU decomposition), saved in m_lower
  106. double band_matrix::saved_diag(int i) const
  107. {
  108. assert((i >= 0) && (i < dim()));
  109. return m_lower[0][i];
  110. }
  111. double &band_matrix::saved_diag(int i)
  112. {
  113. assert((i >= 0) && (i < dim()));
  114. return m_lower[0][i];
  115. }
  116. // LR-Decomposition of a band matrix
  117. void band_matrix::lu_decompose()
  118. {
  119. int i_max, j_max;
  120. int j_min;
  121. double x;
  122. // preconditioning
  123. // normalize column i so that a_ii=1
  124. for (int i = 0; i < this->dim(); i++)
  125. {
  126. assert(this->operator()(i, i) != 0.0);
  127. this->saved_diag(i) = 1.0 / this->operator()(i, i);
  128. j_min = std::max(0, i - this->num_lower());
  129. j_max = std::min(this->dim() - 1, i + this->num_upper());
  130. for (int j = j_min; j <= j_max; j++)
  131. {
  132. this->operator()(i, j) *= this->saved_diag(i);
  133. }
  134. this->operator()(i, i) = 1.0; // prevents rounding errors
  135. }
  136. // Gauss LR-Decomposition
  137. for (int k = 0; k < this->dim(); k++)
  138. {
  139. i_max = std::min(this->dim() - 1, k + this->num_lower()); // num_lower not a mistake!
  140. for (int i = k + 1; i <= i_max; i++)
  141. {
  142. assert(this->operator()(k, k) != 0.0);
  143. x = -this->operator()(i, k) / this->operator()(k, k);
  144. this->operator()(i, k) = -x; // assembly part of L
  145. j_max = std::min(this->dim() - 1, k + this->num_upper());
  146. for (int j = k + 1; j <= j_max; j++)
  147. {
  148. // assembly part of R
  149. this->operator()(i, j) = this->operator()(i, j) + x * this->operator()(k, j);
  150. }
  151. }
  152. }
  153. }
  154. // solves Ly=b
  155. std::vector<double> band_matrix::l_solve(const std::vector<double> &b) const
  156. {
  157. assert(this->dim() == (int)b.size());
  158. std::vector<double> x(this->dim());
  159. int j_start;
  160. double sum;
  161. for (int i = 0; i < this->dim(); i++)
  162. {
  163. sum = 0;
  164. j_start = std::max(0, i - this->num_lower());
  165. for (int j = j_start; j < i; j++)
  166. sum += this->operator()(i, j) * x[j];
  167. x[i] = (b[i] * this->saved_diag(i)) - sum;
  168. }
  169. return x;
  170. }
  171. // solves Rx=y
  172. std::vector<double> band_matrix::r_solve(const std::vector<double> &b) const
  173. {
  174. assert(this->dim() == (int)b.size());
  175. std::vector<double> x(this->dim());
  176. int j_stop;
  177. double sum;
  178. for (int i = this->dim() - 1; i >= 0; i--)
  179. {
  180. sum = 0;
  181. j_stop = std::min(this->dim() - 1, i + this->num_upper());
  182. for (int j = i + 1; j <= j_stop; j++)
  183. sum += this->operator()(i, j) * x[j];
  184. x[i] = (b[i] - sum) / this->operator()(i, i);
  185. }
  186. return x;
  187. }
  188. std::vector<double> band_matrix::lu_solve(const std::vector<double> &b,
  189. bool is_lu_decomposed)
  190. {
  191. assert(this->dim() == (int)b.size());
  192. std::vector<double> x, y;
  193. if (is_lu_decomposed == false)
  194. {
  195. this->lu_decompose();
  196. }
  197. y = this->l_solve(b);
  198. x = this->r_solve(y);
  199. return x;
  200. }
  201. /*
  202. ================================================================================
  203. ================================================================================
  204. spline
  205. ================================================================================
  206. ================================================================================
  207. */
  208. class spline
  209. {
  210. public:
  211. enum bd_type
  212. {
  213. first_deriv = 1,
  214. second_deriv = 2
  215. };
  216. std::vector<double> m_x, m_y; // x,y coordinates of points
  217. // interpolation parameters
  218. // f(x) = a*(x-x_i)^3 + b*(x-x_i)^2 + c*(x-x_i) + y_i
  219. std::vector<double> m_a, m_b, m_c; // spline coefficients 样条曲线系数
  220. double m_b0, m_c0; // for left extrapol
  221. bd_type m_left, m_right; // 下面的构造函数里面默认的是 second_deriv
  222. double m_left_value, m_right_value;
  223. bool m_force_linear_extrapolation;
  224. // set default boundary condition to be zero curvature at both ends
  225. spline() : m_left(second_deriv), m_right(second_deriv), m_left_value(0.0), m_right_value(0.0), m_force_linear_extrapolation(false)
  226. {
  227. ;
  228. }
  229. // optional, but if called it has to come be before set_points() 可选,如果调用,必须要在 set_points() 函数前调用
  230. void set_boundary(bd_type left, double left_value,
  231. bd_type right, double right_value,
  232. bool force_linear_extrapolation = false);
  233. void set_points(const std::vector<double> &x,
  234. const std::vector<double> &y, bool cubic_spline = true); // 这里的 cubic_spline 默认是 true, 调用的时候不给这个参数也是 ok 的
  235. // double operator() (double x) const;
  236. double operator()(double x, int dd = 0) const;
  237. };
  238. // 这个函数可以不调用,类的初始化的时候就有给默认值了
  239. void spline::set_boundary(spline::bd_type left, double left_value, spline::bd_type right, double right_value, bool force_linear_extrapolation)
  240. {
  241. assert(m_x.size() == 0); // set_points() must not have happened yet
  242. m_left = left;
  243. m_right = right;
  244. m_left_value = left_value;
  245. m_right_value = right_value;
  246. m_force_linear_extrapolation = force_linear_extrapolation;
  247. }
  248. void spline::set_points(const std::vector<double> &x, const std::vector<double> &y, bool cubic_spline) //
  249. {
  250. assert(x.size() == y.size());
  251. assert(x.size() > 2);
  252. m_x = x;
  253. m_y = y;
  254. int n = x.size();
  255. // TODO: maybe sort x and y, rather than returning an error
  256. for (int i = 0; i < n - 1; i++)
  257. {
  258. assert(m_x[i] < m_x[i + 1]);
  259. }
  260. if (cubic_spline == true)
  261. { // cubic spline interpolation
  262. // setting up the matrix and right hand side of the equation system
  263. // for the parameters b[]
  264. // 三次样条插值
  265. // 设定矩阵和方程组的右侧
  266. // 参数 b
  267. band_matrix A(n, 1, 1); // 带状矩阵
  268. std::vector<double> rhs(n); // Ax = b 的 b 矩阵
  269. for (int i = 1; i < n - 1; i++)
  270. {
  271. A(i, i - 1) = 1.0 / 3.0 * (x[i] - x[i - 1]);
  272. A(i, i) = 2.0 / 3.0 * (x[i + 1] - x[i - 1]);
  273. A(i, i + 1) = 1.0 / 3.0 * (x[i + 1] - x[i]);
  274. rhs[i] = (y[i + 1] - y[i]) / (x[i + 1] - x[i]) - (y[i] - y[i - 1]) / (x[i] - x[i - 1]);
  275. }
  276. // boundary conditions 边界约束(默认是二阶约束)
  277. // 二阶约束:用户给定曲线两端的二阶导数值(加速度方向)
  278. // 一阶约束:用户给定曲线两端的一阶导数值(速度方向),也就是钳制三次样条!
  279. if (m_left == spline::second_deriv) // 二阶约束
  280. {
  281. // 2*b[0] = f''
  282. A(0, 0) = 2.0;
  283. A(0, 1) = 0.0;
  284. rhs[0] = m_left_value;
  285. }
  286. else if (m_left == spline::first_deriv) // 一阶约束
  287. {
  288. // c[0] = f', needs to be re-expressed in terms of b:
  289. // (2b[0]+b[1])(x[1]-x[0]) = 3 ((y[1]-y[0])/(x[1]-x[0]) - f')
  290. A(0, 0) = 2.0 * (x[1] - x[0]);
  291. A(0, 1) = 1.0 * (x[1] - x[0]);
  292. rhs[0] = 3.0 * ((y[1] - y[0]) / (x[1] - x[0]) - m_left_value);
  293. }
  294. else
  295. {
  296. assert(false); // 如果执行到这里,终止程序的执行,并且在标准错误流中输出错误信息
  297. }
  298. if (m_right == spline::second_deriv)
  299. {
  300. // 2*b[n-1] = f''
  301. A(n - 1, n - 1) = 2.0;
  302. A(n - 1, n - 2) = 0.0;
  303. rhs[n - 1] = m_right_value;
  304. }
  305. else if (m_right == spline::first_deriv)
  306. {
  307. // c[n-1] = f', needs to be re-expressed in terms of b:
  308. // (b[n-2]+2b[n-1])(x[n-1]-x[n-2])
  309. // = 3 (f' - (y[n-1]-y[n-2])/(x[n-1]-x[n-2]))
  310. A(n - 1, n - 1) = 2.0 * (x[n - 1] - x[n - 2]);
  311. A(n - 1, n - 2) = 1.0 * (x[n - 1] - x[n - 2]);
  312. rhs[n - 1] = 3.0 * (m_right_value - (y[n - 1] - y[n - 2]) / (x[n - 1] - x[n - 2]));
  313. }
  314. else
  315. {
  316. assert(false);
  317. }
  318. // solve the equation system to obtain the parameters b[] 求解方程组得到参数 b[]
  319. m_b = A.lu_solve(rhs);
  320. // calculate parameters a[] and c[] based on b[] // 算出来具体的参数
  321. m_a.resize(n);
  322. m_c.resize(n);
  323. for (int i = 0; i < n - 1; i++)
  324. {
  325. m_a[i] = 1.0 / 3.0 * (m_b[i + 1] - m_b[i]) / (x[i + 1] - x[i]);
  326. m_c[i] = (y[i + 1] - y[i]) / (x[i + 1] - x[i]) - 1.0 / 3.0 * (2.0 * m_b[i] + m_b[i + 1]) * (x[i + 1] - x[i]);
  327. }
  328. }
  329. else
  330. { // linear interpolation 线性插值
  331. m_a.resize(n);
  332. m_b.resize(n);
  333. m_c.resize(n);
  334. for (int i = 0; i < n - 1; i++)
  335. {
  336. m_a[i] = 0.0;
  337. m_b[i] = 0.0;
  338. m_c[i] = (m_y[i + 1] - m_y[i]) / (m_x[i + 1] - m_x[i]);
  339. }
  340. }
  341. // for left extrapolation coefficients
  342. m_b0 = (m_force_linear_extrapolation == false) ? m_b[0] : 0.0;
  343. m_c0 = m_c[0];
  344. // for the right extrapolation coefficients
  345. // f_{n-1}(x) = b*(x-x_{n-1})^2 + c*(x-x_{n-1}) + y_{n-1}
  346. double h = x[n - 1] - x[n - 2];
  347. // m_b[n-1] is determined by the boundary condition
  348. m_a[n - 1] = 0.0;
  349. m_c[n - 1] = 3.0 * m_a[n - 2] * h * h + 2.0 * m_b[n - 2] * h + m_c[n - 2]; // = f'_{n-2}(x_{n-1})
  350. if (m_force_linear_extrapolation == true)
  351. m_b[n - 1] = 0.0;
  352. }
  353. double spline::operator()(double x, int dd) const // 运算符重载
  354. {
  355. assert(x >= m_x.front() && x <= m_x.back());
  356. assert(dd == 0 || dd == 1 || dd == 2);
  357. // find the closest point m_x[idx] < x, idx=0 even if x<m_x[0]
  358. // 找到最近的点 m_x[idx] < x, 如果 x<m_x[0] 则 idx=0
  359. std::vector<double>::const_iterator it;
  360. it = std::lower_bound(m_x.begin(), m_x.end(), x); // 使用 STL 中的 lower_bound 函数在 (std::vector<double>)m_x 中查找不小于 x 的第一个元素的位置,并将结果赋给迭代器 it
  361. int idx = std::max(int(it - m_x.begin()) - 1, 0); // 找到 小于 x 的最大索引 idx
  362. // 因为是分段三次样条曲线,这样做是为了找到对应的那段三次样条曲线
  363. double h = x - m_x[idx]; // 如果是 xt_ 对象,那么这里的 h 就是 时间 t 是小数位,因为每段三次样条曲线的 t 都是 0 <= t <= 1
  364. double interpol;
  365. // interpolation 插值(好像不叫什么插值? 就是根据样条曲线参数求出来对应 函数值、 一阶导数值、 二阶导数值)
  366. // m_a, m_b, m_c 分别是三次样条曲线的参数,方程应该是 S(t) = m_a * t^3 + m_b * t^2 + m_c * t + m_y , 其中, 0 <= t <= 1
  367. if (dd == 0)
  368. {
  369. interpol = ((m_a[idx] * h + m_b[idx]) * h + m_c[idx]) * h + m_y[idx]; //
  370. }
  371. if (dd == 1)
  372. {
  373. interpol = (3 * m_a[idx] * h + 2 * m_b[idx]) * h + m_c[idx];
  374. }
  375. else if (dd == 2)
  376. {
  377. interpol = 6 * m_a[idx] * h + 2 * m_b[idx];
  378. }
  379. return interpol;
  380. }
  381. /*
  382. ================================================================================
  383. ================================================================================
  384. ArcSpline
  385. ================================================================================
  386. ================================================================================
  387. */
  388. class ArcSpline
  389. {
  390. private:
  391. double arcL_; // 整条三次样条曲线的弧长
  392. spline xs_, ys_;
  393. std::vector<double> sL_;
  394. std::vector<double> xL_;
  395. std::vector<double> yL_;
  396. public:
  397. ArcSpline(){};
  398. ~ArcSpline(){};
  399. void setWayPoints(const std::vector<double> &x_waypoints,
  400. const std::vector<double> &y_waypoints);
  401. Eigen::Vector2d operator()(double s, int n = 0);
  402. double findS(const Eigen::Vector2d &p);
  403. inline double arcL()
  404. {
  405. return arcL_;
  406. };
  407. };
  408. void ArcSpline::setWayPoints(const std::vector<double> &x_waypoints,
  409. const std::vector<double> &y_waypoints)
  410. {
  411. assert(x_waypoints.size() == y_waypoints.size()); // 一个在运行时进行条件检查的宏,如果为 false 程序显示错误信息并终止执行
  412. spline xt_, yt_; // 存储曲线的 x 和 y 坐标的样条插值对象
  413. std::vector<double> t_list;
  414. std::vector<double> x_list = x_waypoints; // 样条曲线的控制点
  415. std::vector<double> y_list = y_waypoints;
  416. // add front to back -> a loop
  417. // x_list.push_back(x_list.front());
  418. // y_list.push_back(y_list.front());
  419. t_list.clear(); // 这行我自己加的
  420. for (size_t i = 0; i < x_list.size(); ++i)
  421. {
  422. t_list.push_back(i); // t_list = [0, 1, 2, 3, 4, 5, 6, ......]
  423. }
  424. xt_.set_points(t_list, x_list); // 根据 t 和 x 坐标,求解出对应的三次样条曲线参数(存储在 xt_.m_a, xt_.m_b, xt_.m_c 中)
  425. yt_.set_points(t_list, y_list);
  426. // calculate arc length approximately 近似计算弧长(差分法)
  427. double res = 1e-2; // 差分法的分辨率 差分分辨率
  428. sL_.clear();
  429. xL_.clear(); // 这两行clear() 都是我加的
  430. yL_.clear();
  431. sL_.push_back(0); // 从起点到当前点的累计弧长 是弧长吗?
  432. xL_.push_back(xt_(0)); // 从起点到当前点 x 坐标的累计弧长
  433. yL_.push_back(yt_(0)); // 从起点到当前点 y 坐标的累计弧长
  434. double last_s = 0;
  435. for (double t = res; t < t_list.back(); t += res)
  436. {
  437. double left_arc = res * std::sqrt(xt_(t - res, 1) * xt_(t - res, 1) + yt_(t - res, 1) * yt_(t - res, 1)); // 这里的括号用了运算符重载
  438. double right_arc = res * std::sqrt(xt_(t, 1) * xt_(t, 1) + yt_(t, 1) * yt_(t, 1));
  439. // xt_(t - res, 1) xt_ 是三次样条曲线的对象,里面求解了分段三次样条曲线的相关参数
  440. // 括号用了运算符重载,xt_(t, 1) 求的就是 在横坐标 t 处的 1 阶导数值
  441. sL_.push_back(last_s + (left_arc + right_arc) / 2); // 这里确实是求弧长喔,
  442. last_s += (left_arc + right_arc) / 2;
  443. xL_.push_back(xt_(t)); // 括号运算符重载中,默认第二个参数是 0, 所以不输入的话就是计算函数值
  444. yL_.push_back(yt_(t));
  445. }
  446. xs_.set_points(sL_, xL_); // xL 关于 sL 的三次样条曲线 可以通过均匀的弧长改进对象的移动平滑性!
  447. ys_.set_points(sL_, yL_);
  448. arcL_ = sL_.back(); // 三次样条曲线的总弧长
  449. };
  450. Eigen::Vector2d ArcSpline::operator()(double s, int n /* =0 */) // n 默认值是 0
  451. {
  452. assert(n >= 0 && n <= 2);
  453. Eigen::Vector2d p(xs_(s, n), ys_(s, n)); // 返回在弧长 s 处 x 和 y 的 n 阶导数值
  454. return p;
  455. };
  456. inline double ArcSpline::findS(const Eigen::Vector2d &p) // 这个函数干嘛用的?
  457. {
  458. // TODO a more efficient and accurate method
  459. double min_dist = (operator()(sL_.front()) - p).norm();
  460. double min_s = sL_.front(); // sL_.front() 返回的是 第一个元素,上面代码中,sL_.resize(0) 之后马上就是sL_.push_back(0),所以这里min_s 应该就直接是等于 0
  461. for (double s = sL_.front(); s < sL_.back(); s += 0.01) // s 从 0 开始遍历完整条曲线
  462. {
  463. double dist = (operator()(s) - p).norm();
  464. if (dist < min_dist)
  465. {
  466. min_s = s;
  467. min_dist = dist;
  468. }
  469. }
  470. return min_s;
  471. }
  472. }