#pragma once

#include <Eigen/Core>
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

namespace arc_spline
{

/*
================================================================================
================================================================================
                                band_matrix
================================================================================
================================================================================
*/

  // band matrix solver   带矩阵求解器
  class band_matrix
  {
  private:
    std::vector<std::vector<double>> m_upper; // upper band
    std::vector<std::vector<double>> m_lower; // lower band
  public:
    band_matrix(){};                        // constructor
    band_matrix(int dim, int n_u, int n_l); // constructor
    ~band_matrix(){};                       // destructor
    void resize(int dim, int n_u, int n_l); // init with dim,n_u,n_l
    int dim() const;                        // matrix dimension
    int num_upper() const
    {
      return m_upper.size() - 1;
    }
    int num_lower() const
    {
      return m_lower.size() - 1;
    }
    // access operator
    double &operator()(int i, int j);      // write
    double operator()(int i, int j) const; // read
    // we can store an additional diogonal (in m_lower)
    double &saved_diag(int i);
    double saved_diag(int i) const;
    void lu_decompose();
    std::vector<double> r_solve(const std::vector<double> &b) const;
    std::vector<double> l_solve(const std::vector<double> &b) const;
    std::vector<double> lu_solve(const std::vector<double> &b,
                                 bool is_lu_decomposed = false);
  };

  band_matrix::band_matrix(int dim, int n_u, int n_l)
  {
    resize(dim, n_u, n_l);
  }

  void band_matrix::resize(int dim, int n_u, int n_l)
  {
    assert(dim > 0);
    assert(n_u >= 0);
    assert(n_l >= 0);
    m_upper.resize(n_u + 1);
    m_lower.resize(n_l + 1);
    for (size_t i = 0; i < m_upper.size(); i++)
    {
      m_upper[i].resize(dim);
    }
    for (size_t i = 0; i < m_lower.size(); i++)
    {
      m_lower[i].resize(dim);
    }
  }

  int band_matrix::dim() const
  {
    if (m_upper.size() > 0)
    {
      return m_upper[0].size();
    }
    else
    {
      return 0;
    }
  }

  // defines the new operator (), so that we can access the elements
  // by A(i,j), index going from i=0,...,dim()-1
  // 定义新的运算符 (),以便我们可以访问元素
  // 通过 A(i,j),索引从 i=0,...,dim()-1 开始
  double &band_matrix::operator()(int i, int j)
  {
    int k = j - i; // what band is the entry
    assert((i >= 0) && (i < dim()) && (j >= 0) && (j < dim()));
    assert((-num_lower() <= k) && (k <= num_upper()));
    // k=0 -> diogonal, k<0 lower left part, k>0 upper right part
    if (k >= 0)
      return m_upper[k][i];
    else
      return m_lower[-k][i];
  }
  double band_matrix::operator()(int i, int j) const
  {
    int k = j - i; // what band is the entry
    assert((i >= 0) && (i < dim()) && (j >= 0) && (j < dim()));
    assert((-num_lower() <= k) && (k <= num_upper()));
    // k=0 -> diogonal, k<0 lower left part, k>0 upper right part
    if (k >= 0)
      return m_upper[k][i];
    else
      return m_lower[-k][i];
  }

  // second diag (used in LU decomposition), saved in m_lower
  double band_matrix::saved_diag(int i) const
  {
    assert((i >= 0) && (i < dim()));
    return m_lower[0][i];
  }
  double &band_matrix::saved_diag(int i)
  {
    assert((i >= 0) && (i < dim()));
    return m_lower[0][i];
  }

  // LR-Decomposition of a band matrix
  void band_matrix::lu_decompose()
  {
    int i_max, j_max;
    int j_min;
    double x;

    // preconditioning
    // normalize column i so that a_ii=1
    for (int i = 0; i < this->dim(); i++)
    {
      assert(this->operator()(i, i) != 0.0);
      this->saved_diag(i) = 1.0 / this->operator()(i, i);
      j_min = std::max(0, i - this->num_lower());
      j_max = std::min(this->dim() - 1, i + this->num_upper());
      for (int j = j_min; j <= j_max; j++)
      {
        this->operator()(i, j) *= this->saved_diag(i);
      }
      this->operator()(i, i) = 1.0; // prevents rounding errors
    }

    // Gauss LR-Decomposition
    for (int k = 0; k < this->dim(); k++)
    {
      i_max = std::min(this->dim() - 1, k + this->num_lower()); // num_lower not a mistake!
      for (int i = k + 1; i <= i_max; i++)
      {
        assert(this->operator()(k, k) != 0.0);
        x = -this->operator()(i, k) / this->operator()(k, k);
        this->operator()(i, k) = -x; // assembly part of L
        j_max = std::min(this->dim() - 1, k + this->num_upper());
        for (int j = k + 1; j <= j_max; j++)
        {
          // assembly part of R
          this->operator()(i, j) = this->operator()(i, j) + x * this->operator()(k, j);
        }
      }
    }
  }

  // solves Ly=b
  std::vector<double> band_matrix::l_solve(const std::vector<double> &b) const
  {
    assert(this->dim() == (int)b.size());
    std::vector<double> x(this->dim());
    int j_start;
    double sum;
    for (int i = 0; i < this->dim(); i++)
    {
      sum = 0;
      j_start = std::max(0, i - this->num_lower());
      for (int j = j_start; j < i; j++)
        sum += this->operator()(i, j) * x[j];
      x[i] = (b[i] * this->saved_diag(i)) - sum;
    }
    return x;
  }

  // solves Rx=y
  std::vector<double> band_matrix::r_solve(const std::vector<double> &b) const
  {
    assert(this->dim() == (int)b.size());
    std::vector<double> x(this->dim());
    int j_stop;
    double sum;
    for (int i = this->dim() - 1; i >= 0; i--)
    {
      sum = 0;
      j_stop = std::min(this->dim() - 1, i + this->num_upper());
      for (int j = i + 1; j <= j_stop; j++)
        sum += this->operator()(i, j) * x[j];
      x[i] = (b[i] - sum) / this->operator()(i, i);
    }
    return x;
  }

  std::vector<double> band_matrix::lu_solve(const std::vector<double> &b,
                                            bool is_lu_decomposed)
  {
    assert(this->dim() == (int)b.size());
    std::vector<double> x, y;
    if (is_lu_decomposed == false)
    {
      this->lu_decompose();
    }
    y = this->l_solve(b);
    x = this->r_solve(y);
    return x;
  }

  /*
  ================================================================================
  ================================================================================
                                      spline
  ================================================================================
  ================================================================================
  */
  class spline
  {
  public:
    enum bd_type
    {
      first_deriv = 1,
      second_deriv = 2
    };

    std::vector<double> m_x, m_y; // x,y coordinates of points
    // interpolation parameters
    // f(x) = a*(x-x_i)^3 + b*(x-x_i)^2 + c*(x-x_i) + y_i
    std::vector<double> m_a, m_b, m_c; // spline coefficients  样条曲线系数
    double m_b0, m_c0;                 // for left extrapol
    bd_type m_left, m_right;           // 下面的构造函数里面默认的是 second_deriv
    double m_left_value, m_right_value;
    bool m_force_linear_extrapolation;

    // set default boundary condition to be zero curvature at both ends
    spline() : m_left(second_deriv), m_right(second_deriv), m_left_value(0.0), m_right_value(0.0), m_force_linear_extrapolation(false)
    {
      ;
    }

    // optional, but if called it has to come be before set_points()    可选,如果调用,必须要在 set_points() 函数前调用
    void set_boundary(bd_type left, double left_value,
                      bd_type right, double right_value,
                      bool force_linear_extrapolation = false);
    void set_points(const std::vector<double> &x,
                    const std::vector<double> &y, bool cubic_spline = true); // 这里的 cubic_spline 默认是 true, 调用的时候不给这个参数也是 ok 的
    // double operator() (double x) const;
    double operator()(double x, int dd = 0) const;
  };

  // 这个函数可以不调用,类的初始化的时候就有给默认值了
  void spline::set_boundary(spline::bd_type left, double left_value, spline::bd_type right, double right_value, bool force_linear_extrapolation)
  {
    assert(m_x.size() == 0); // set_points() must not have happened yet
    m_left = left;
    m_right = right;
    m_left_value = left_value;
    m_right_value = right_value;
    m_force_linear_extrapolation = force_linear_extrapolation;
  }

  void spline::set_points(const std::vector<double> &x, const std::vector<double> &y, bool cubic_spline) //
  {
    assert(x.size() == y.size());
    assert(x.size() > 2);
    m_x = x;
    m_y = y;
    int n = x.size();
    // TODO: maybe sort x and y, rather than returning an error
    for (int i = 0; i < n - 1; i++)
    {
      assert(m_x[i] < m_x[i + 1]);
    }

    if (cubic_spline == true)
    { // cubic spline interpolation
      // setting up the matrix and right hand side of the equation system
      // for the parameters b[]
      // 三次样条插值
      // 设定矩阵和方程组的右侧
      // 参数 b
      band_matrix A(n, 1, 1);     // 带状矩阵
      std::vector<double> rhs(n); // Ax = b 的 b 矩阵
      for (int i = 1; i < n - 1; i++)
      {
        A(i, i - 1) = 1.0 / 3.0 * (x[i] - x[i - 1]);
        A(i, i) = 2.0 / 3.0 * (x[i + 1] - x[i - 1]);
        A(i, i + 1) = 1.0 / 3.0 * (x[i + 1] - x[i]);
        rhs[i] = (y[i + 1] - y[i]) / (x[i + 1] - x[i]) - (y[i] - y[i - 1]) / (x[i] - x[i - 1]);
      }

      // boundary conditions  边界约束(默认是二阶约束)
      // 二阶约束:用户给定曲线两端的二阶导数值(加速度方向)
      // 一阶约束:用户给定曲线两端的一阶导数值(速度方向),也就是钳制三次样条!

      if (m_left == spline::second_deriv) // 二阶约束
      {
        // 2*b[0] = f''
        A(0, 0) = 2.0;
        A(0, 1) = 0.0;
        rhs[0] = m_left_value;
      }
      else if (m_left == spline::first_deriv) // 一阶约束
      {
        // c[0] = f', needs to be re-expressed in terms of b:
        // (2b[0]+b[1])(x[1]-x[0]) = 3 ((y[1]-y[0])/(x[1]-x[0]) - f')
        A(0, 0) = 2.0 * (x[1] - x[0]);
        A(0, 1) = 1.0 * (x[1] - x[0]);
        rhs[0] = 3.0 * ((y[1] - y[0]) / (x[1] - x[0]) - m_left_value);
      }
      else
      {
        assert(false); // 如果执行到这里,终止程序的执行,并且在标准错误流中输出错误信息
      }

      if (m_right == spline::second_deriv)
      {
        // 2*b[n-1] = f''
        A(n - 1, n - 1) = 2.0;
        A(n - 1, n - 2) = 0.0;
        rhs[n - 1] = m_right_value;
      }
      else if (m_right == spline::first_deriv)
      {
        // c[n-1] = f', needs to be re-expressed in terms of b:
        // (b[n-2]+2b[n-1])(x[n-1]-x[n-2])
        // = 3 (f' - (y[n-1]-y[n-2])/(x[n-1]-x[n-2]))
        A(n - 1, n - 1) = 2.0 * (x[n - 1] - x[n - 2]);
        A(n - 1, n - 2) = 1.0 * (x[n - 1] - x[n - 2]);
        rhs[n - 1] = 3.0 * (m_right_value - (y[n - 1] - y[n - 2]) / (x[n - 1] - x[n - 2]));
      }
      else
      {
        assert(false);
      }

      // solve the equation system to obtain the parameters b[]   求解方程组得到参数 b[]
      m_b = A.lu_solve(rhs);

      // calculate parameters a[] and c[] based on b[]  // 算出来具体的参数      
      m_a.resize(n);
      m_c.resize(n);
      for (int i = 0; i < n - 1; i++)
      {
        m_a[i] = 1.0 / 3.0 * (m_b[i + 1] - m_b[i]) / (x[i + 1] - x[i]);
        m_c[i] = (y[i + 1] - y[i]) / (x[i + 1] - x[i]) - 1.0 / 3.0 * (2.0 * m_b[i] + m_b[i + 1]) * (x[i + 1] - x[i]);
      }
    }
    else
    { // linear interpolation     线性插值
      m_a.resize(n);
      m_b.resize(n);
      m_c.resize(n);
      for (int i = 0; i < n - 1; i++)
      {
        m_a[i] = 0.0;
        m_b[i] = 0.0;
        m_c[i] = (m_y[i + 1] - m_y[i]) / (m_x[i + 1] - m_x[i]);
      }
    }

    // for left extrapolation coefficients
    m_b0 = (m_force_linear_extrapolation == false) ? m_b[0] : 0.0;
    m_c0 = m_c[0];

    // for the right extrapolation coefficients
    // f_{n-1}(x) = b*(x-x_{n-1})^2 + c*(x-x_{n-1}) + y_{n-1}
    double h = x[n - 1] - x[n - 2];
    // m_b[n-1] is determined by the boundary condition
    m_a[n - 1] = 0.0;
    m_c[n - 1] = 3.0 * m_a[n - 2] * h * h + 2.0 * m_b[n - 2] * h + m_c[n - 2]; // = f'_{n-2}(x_{n-1})
    if (m_force_linear_extrapolation == true)
      m_b[n - 1] = 0.0;
  }

  double spline::operator()(double x, int dd) const // 运算符重载
  {
    assert(x >= m_x.front() && x <= m_x.back());
    assert(dd == 0 || dd == 1 || dd == 2);

    // find the closest point m_x[idx] < x, idx=0 even if x<m_x[0]
    // 找到最近的点 m_x[idx] < x, 如果 x<m_x[0]  则 idx=0
    std::vector<double>::const_iterator it;
    it = std::lower_bound(m_x.begin(), m_x.end(), x); // 使用 STL 中的 lower_bound 函数在 (std::vector<double>)m_x 中查找不小于 x 的第一个元素的位置,并将结果赋给迭代器 it
    int idx = std::max(int(it - m_x.begin()) - 1, 0); // 找到 小于 x 的最大索引 idx

    // 因为是分段三次样条曲线,这样做是为了找到对应的那段三次样条曲线
    double h = x - m_x[idx]; // 如果是 xt_ 对象,那么这里的 h 就是 时间 t 是小数位,因为每段三次样条曲线的 t 都是 0 <= t <= 1
    double interpol;

    // interpolation  插值(好像不叫什么插值?   就是根据样条曲线参数求出来对应 函数值、 一阶导数值、 二阶导数值)

    // m_a, m_b, m_c 分别是三次样条曲线的参数,方程应该是    S(t) = m_a * t^3 + m_b * t^2 + m_c * t + m_y ,  其中, 0 <= t <= 1
    if (dd == 0)
    {
      interpol = ((m_a[idx] * h + m_b[idx]) * h + m_c[idx]) * h + m_y[idx]; //
    }
    if (dd == 1)
    {
      interpol = (3 * m_a[idx] * h + 2 * m_b[idx]) * h + m_c[idx];
    }
    else if (dd == 2)
    {
      interpol = 6 * m_a[idx] * h + 2 * m_b[idx];
    }
    return interpol;
  }

  /*
  ================================================================================
  ================================================================================
                                    ArcSpline
  ================================================================================
  ================================================================================
  */

  class ArcSpline
  {
  private:
    double arcL_; // 整条三次样条曲线的弧长
    spline xs_, ys_;
    std::vector<double> sL_;
    std::vector<double> xL_;
    std::vector<double> yL_;

  public:
    ArcSpline(){};
    ~ArcSpline(){};
    void setWayPoints(const std::vector<double> &x_waypoints,
                      const std::vector<double> &y_waypoints);
    Eigen::Vector2d operator()(double s, int n = 0);
    double findS(const Eigen::Vector2d &p);
    inline double arcL()
    {
      return arcL_;
    };
  };

  void ArcSpline::setWayPoints(const std::vector<double> &x_waypoints,
                               const std::vector<double> &y_waypoints)
  {
    assert(x_waypoints.size() == y_waypoints.size()); // 一个在运行时进行条件检查的宏,如果为 false 程序显示错误信息并终止执行
    spline xt_, yt_;                                  // 存储曲线的 x 和 y 坐标的样条插值对象
    std::vector<double> t_list;
    std::vector<double> x_list = x_waypoints; // 样条曲线的控制点
    std::vector<double> y_list = y_waypoints;
    // add front to back -> a loop
    // x_list.push_back(x_list.front());
    // y_list.push_back(y_list.front());
    t_list.clear();   // 这行我自己加的
    for (size_t i = 0; i < x_list.size(); ++i)
    {
      t_list.push_back(i); // t_list = [0, 1, 2, 3, 4, 5, 6, ......]
    }
    xt_.set_points(t_list, x_list); // 根据 t 和 x 坐标,求解出对应的三次样条曲线参数(存储在 xt_.m_a, xt_.m_b, xt_.m_c 中)
    yt_.set_points(t_list, y_list);

    // calculate arc length approximately     近似计算弧长(差分法)
    double res = 1e-2; // 差分法的分辨率   差分分辨率
    sL_.clear();  
    xL_.clear();  // 这两行clear()  都是我加的
    yL_.clear();
    sL_.push_back(0);      // 从起点到当前点的累计弧长   是弧长吗?
    xL_.push_back(xt_(0)); // 从起点到当前点 x 坐标的累计弧长
    yL_.push_back(yt_(0)); // 从起点到当前点 y 坐标的累计弧长

    double last_s = 0;
    for (double t = res; t < t_list.back(); t += res)
    {
      double left_arc = res * std::sqrt(xt_(t - res, 1) * xt_(t - res, 1) + yt_(t - res, 1) * yt_(t - res, 1)); // 这里的括号用了运算符重载
      double right_arc = res * std::sqrt(xt_(t, 1) * xt_(t, 1) + yt_(t, 1) * yt_(t, 1));
      // xt_(t - res, 1)     xt_ 是三次样条曲线的对象,里面求解了分段三次样条曲线的相关参数
      // 括号用了运算符重载,xt_(t, 1)  求的就是 在横坐标 t 处的 1 阶导数值

      sL_.push_back(last_s + (left_arc + right_arc) / 2); // 这里确实是求弧长喔,

      last_s += (left_arc + right_arc) / 2;
      xL_.push_back(xt_(t)); // 括号运算符重载中,默认第二个参数是 0, 所以不输入的话就是计算函数值
      yL_.push_back(yt_(t));
    }
    xs_.set_points(sL_, xL_); // xL 关于 sL 的三次样条曲线    可以通过均匀的弧长改进对象的移动平滑性!
    ys_.set_points(sL_, yL_);
    arcL_ = sL_.back(); // 三次样条曲线的总弧长
  };

  Eigen::Vector2d ArcSpline::operator()(double s, int n /* =0 */) // n 默认值是 0
  {
    assert(n >= 0 && n <= 2);
    Eigen::Vector2d p(xs_(s, n), ys_(s, n)); // 返回在弧长 s 处  x 和 y 的 n 阶导数值
    return p;
  };

  inline double ArcSpline::findS(const Eigen::Vector2d &p) // 这个函数干嘛用的?
  {
    // TODO a more efficient and accurate method
    double min_dist = (operator()(sL_.front()) - p).norm();
    double min_s = sL_.front();                             // sL_.front() 返回的是 第一个元素,上面代码中,sL_.resize(0)  之后马上就是sL_.push_back(0),所以这里min_s 应该就直接是等于 0
    for (double s = sL_.front(); s < sL_.back(); s += 0.01) // s 从 0 开始遍历完整条曲线
    {
      double dist = (operator()(s) - p).norm();
      if (dist < min_dist)
      {
        min_s = s;
        min_dist = dist;
      }
    }
    return min_s;
  }
}