#include <scheme/reconstruction_utils/ElementIntegralReconstructionMatrixBuilder.hpp>

#include <analysis/GaussLegendreQuadratureDescriptor.hpp>
#include <analysis/GaussQuadratureDescriptor.hpp>
#include <geometry/CubeTransformation.hpp>
#include <geometry/LineTransformation.hpp>
#include <geometry/PrismTransformation.hpp>
#include <geometry/PyramidTransformation.hpp>
#include <geometry/SquareTransformation.hpp>
#include <geometry/SymmetryUtils.hpp>
#include <geometry/TetrahedronTransformation.hpp>
#include <geometry/TriangleTransformation.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshDataManager.hpp>
#include <scheme/DiscreteFunctionDPk.hpp>

// Quadrature evaluation of the means of the monomials centred at Xj on one cell.
//
// For every basis index k >= 1 (the constant monomial e_0 is skipped), this computes
//   sum_q  wq * det(J)(xi_q) * e_k(T(xi_q))  /  Vi
// where e_k(X) = (x - xj)^a (y - yj)^b (z - zj)^c with 0 < a + b + c <= m_polynomial_degree.
//
// quadrature  : rule on the reference element; its points xi_q are mapped with T.
// T           : reference-to-physical transformation of the cell.
// Xj          : centre of the monomials (X_Xj = T(xi_q) - Xj).
// Vi          : normalisation factor; callers in this file pass m_Vj of the cell,
//               so the result is the mean of e_k over the cell.
// mean_of_ejk : output, overwritten. Entry k-1 receives the value for e_k,
//               k = 1 .. m_basis_dimension - 1, so it must hold at least
//               m_basis_dimension - 1 entries.
//
// m_wq_detJ_ek is scratch storage. At each quadrature point it holds
// wq * detJ * e_k for every k. It is built by successive multiplications so
// that no power is evaluated explicitly.
template <MeshConcept MeshTypeT>
template <typename ConformTransformationT>
void
PolynomialReconstruction::ElementIntegralReconstructionMatrixBuilder<MeshTypeT>::_computeEjkMean(
  const QuadratureFormula<MeshType::Dimension>& quadrature,
  const ConformTransformationT& T,
  const Rd& Xj,
  const double Vi,
  SmallArray<double>& mean_of_ejk) noexcept(NO_ASSERT)
{
  mean_of_ejk.fill(0);

  for (size_t i_q = 0; i_q < quadrature.numberOfPoints(); ++i_q) {
    const double wq = quadrature.weight(i_q);
    const Rd& xi_q  = quadrature.point(i_q);

    // Physical position of the quadrature point, relative to the monomial centre.
    const Rd X_Xj = T(xi_q) - Xj;

    if constexpr (MeshType::Dimension == 1) {
      const double detT = T.jacobianDeterminant();

      const double x_xj = X_Xj[0];

      {
        // 1D ordering: e_k = (x - xj)^k.
        m_wq_detJ_ek[0] = wq * detT;
        for (size_t k = 1; k <= m_polynomial_degree; ++k) {
          m_wq_detJ_ek[k] = x_xj * m_wq_detJ_ek[k - 1];
        }
      }

    } else if constexpr (MeshType::Dimension == 2) {
      // Triangles use the argument-less jacobianDeterminant() (presumably an
      // affine map with constant Jacobian). Other cells evaluate it at xi_q.
      const double detT = [&] {
        if constexpr (std::is_same_v<TriangleTransformation<2>, std::decay_t<decltype(T)>>) {
          return T.jacobianDeterminant();
        } else {
          return T.jacobianDeterminant(xi_q);
        }
      }();

      const double x_xj = X_Xj[0];
      const double y_yj = X_Xj[1];

      {
        // First row: pure x powers 0 .. degree.
        size_t k          = 0;
        m_wq_detJ_ek[k++] = wq * detT;
        for (; k <= m_polynomial_degree; ++k) {
          m_wq_detJ_ek[k] = x_xj * m_wq_detJ_ek[k - 1];
        }

        // Row i_y holds x^l y^{i_y}, l = 0 .. degree - i_y. Each entry is the
        // matching entry of row i_y - 1 (starting at m_y_row_index[i_y - 1])
        // multiplied by (y - yj). k keeps running, so rows are stored contiguously.
        for (size_t i_y = 1; i_y <= m_polynomial_degree; ++i_y) {
          const size_t begin_i_y_1 = m_y_row_index[i_y - 1];
          for (size_t l = 0; l <= m_polynomial_degree - i_y; ++l, ++k) {
            m_wq_detJ_ek[k] = y_yj * m_wq_detJ_ek[begin_i_y_1 + l];
          }
        }
      }

    } else if constexpr (MeshType::Dimension == 3) {
      // Redundant with the enclosing if constexpr; kept as a guard.
      static_assert(MeshType::Dimension == 3);

      // Tetrahedra use the argument-less jacobianDeterminant() (presumably an
      // affine map with constant Jacobian). Other cells evaluate it at xi_q.
      const double detT = [&] {
        if constexpr (std::is_same_v<TetrahedronTransformation, std::decay_t<decltype(T)>>) {
          return T.jacobianDeterminant();
        } else {
          return T.jacobianDeterminant(xi_q);
        }
      }();

      const double x_xj = X_Xj[0];
      const double y_yj = X_Xj[1];
      const double z_zj = X_Xj[2];

      {
        // Monomials are stored as rows of x powers, one row per (y, z) power
        // pair. m_yz_row_index / m_yz_row_size give each row's start and
        // length, and m_z_triangle_index[i_z] is the first row of the z^{i_z}
        // layer (see the constructor).

        // Row 0: pure x powers.
        size_t k          = 0;
        m_wq_detJ_ek[k++] = wq * detT;
        for (; k <= m_polynomial_degree; ++k) {
          m_wq_detJ_ek[k] = x_xj * m_wq_detJ_ek[k - 1];
        }

        // Rest of the z^0 layer: row i_y = row i_y - 1 times (y - yj).
        for (size_t i_y = 1; i_y <= m_polynomial_degree; ++i_y) {
          const size_t begin_i_y_1 = m_yz_row_index[i_y - 1];
          const size_t nb_monoms   = m_yz_row_size[i_y];
          for (size_t l = 0; l < nb_monoms; ++l, ++k) {
            m_wq_detJ_ek[k] = y_yj * m_wq_detJ_ek[begin_i_y_1 + l];
          }
        }

        // Layer z^{i_z}: each row is the row with the same y power in layer
        // z^{i_z - 1}, truncated to the shorter length and multiplied by (z - zj).
        for (size_t i_z = 1; i_z <= m_polynomial_degree; ++i_z) {
          const size_t nb_y      = m_yz_row_size[m_z_triangle_index[i_z]];
          const size_t index_z   = m_z_triangle_index[i_z];
          const size_t index_z_1 = m_z_triangle_index[i_z - 1];
          for (size_t i_y = 0; i_y < nb_y; ++i_y) {
            const size_t begin_i_yz_1 = m_yz_row_index[index_z_1 + i_y];
            const size_t nb_monoms    = m_yz_row_size[index_z + i_y];
            for (size_t l = 0; l < nb_monoms; ++l, ++k) {
              m_wq_detJ_ek[k] = z_zj * m_wq_detJ_ek[begin_i_yz_1 + l];
            }
          }
        }
      }
    }

    // Accumulate every non-constant monomial; entry k of the scratch array goes to k-1.
    for (size_t k = 1; k < m_basis_dimension; ++k) {
      mean_of_ejk[k - 1] += m_wq_detJ_ek[k];
    }
  }

  // Turn the integrals into means.
  const double inv_Vi = 1. / Vi;
  for (size_t k = 0; k < mean_of_ejk.size(); ++k) {
    mean_of_ejk[k] *= inv_Vi;
  }
}

template <MeshConcept MeshTypeT>
void
PolynomialReconstruction::ElementIntegralReconstructionMatrixBuilder<MeshTypeT>::_computeEjkMean(
  const Rd& Xj,
  const CellId& cell_i_id,
  SmallArray<double>& mean_of_ejk)
{
  // Computes, over cell `cell_i_id`, the means of the non-constant monomials
  // centred at Xj.
  //
  // This overload picks, from the cell type, the reference-to-physical
  // transformation and a quadrature rule. It then forwards to the
  // quadrature-based overload, which stores the results in mean_of_ejk.
  //
  // Throws NotImplementedError when the cell type is not handled in the
  // current space dimension.
  const auto cell_nodes    = m_cell_to_node_matrix[cell_i_id];
  const CellType cell_kind = m_cell_type[cell_i_id];
  const double cell_volume = m_Vj[cell_i_id];

  // Coordinates of the `local_node`-th node of the cell.
  auto node_position = [&](const size_t local_node) { return m_xr[cell_nodes[local_node]]; };

  if constexpr (MeshType::Dimension == 1) {
    switch (cell_kind) {
    case CellType::Line: {
      const LineTransformation<1> T{node_position(0), node_position(1)};
      const auto& quadrature =
        QuadratureManager::instance().getLineFormula(GaussLegendreQuadratureDescriptor{m_polynomial_degree});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    default: {
      throw NotImplementedError("unexpected cell type: " + std::string{name(cell_kind)});
    }
    }
  } else if constexpr (MeshType::Dimension == 2) {
    switch (cell_kind) {
    case CellType::Triangle: {
      const TriangleTransformation<2> T{node_position(0), node_position(1), node_position(2)};
      const auto& quadrature =
        QuadratureManager::instance().getTriangleFormula(GaussQuadratureDescriptor{m_polynomial_degree});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    case CellType::Quadrangle: {
      // The Jacobian determinant of the bilinear map is not constant, hence
      // one extra order for the quadrature.
      const SquareTransformation<2> T{node_position(0), node_position(1), node_position(2), node_position(3)};
      const auto& quadrature =
        QuadratureManager::instance().getSquareFormula(GaussLegendreQuadratureDescriptor{m_polynomial_degree + 1});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    default: {
      throw NotImplementedError("unexpected cell type: " + std::string{name(cell_kind)});
    }
    }
  } else {
    static_assert(MeshType::Dimension == 3);

    // NOTE(review): non-affine 3D cells use one extra quadrature order (+1).
    // Confirm this is enough for the intended exactness of the integrals.
    switch (cell_kind) {
    case CellType::Tetrahedron: {
      const TetrahedronTransformation T{node_position(0), node_position(1), node_position(2), node_position(3)};
      const auto& quadrature =
        QuadratureManager::instance().getTetrahedronFormula(GaussQuadratureDescriptor{m_polynomial_degree});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    case CellType::Prism: {
      const PrismTransformation T{node_position(0), node_position(1), node_position(2),   //
                                  node_position(3), node_position(4), node_position(5)};
      const auto& quadrature =
        QuadratureManager::instance().getPrismFormula(GaussQuadratureDescriptor{m_polynomial_degree + 1});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    case CellType::Pyramid: {
      const PyramidTransformation T{node_position(0), node_position(1), node_position(2), node_position(3),
                                    node_position(4)};
      const auto& quadrature =
        QuadratureManager::instance().getPyramidFormula(GaussQuadratureDescriptor{m_polynomial_degree + 1});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    case CellType::Hexahedron: {
      const CubeTransformation T{node_position(0), node_position(1), node_position(2), node_position(3),
                                 node_position(4), node_position(5), node_position(6), node_position(7)};
      const auto& quadrature =
        QuadratureManager::instance().getCubeFormula(GaussLegendreQuadratureDescriptor{m_polynomial_degree + 1});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    default: {
      throw NotImplementedError("unexpected cell type: " + std::string{name(cell_kind)});
    }
    }
  }
}

template <MeshConcept MeshTypeT>
void
PolynomialReconstruction::ElementIntegralReconstructionMatrixBuilder<MeshTypeT>::_computeEjkMeanInSymmetricCell(
  const Rd& origin,
  const Rd& normal,
  const Rd& Xj,
  const CellId& cell_i_id,
  SmallArray<double>& mean_of_ejk)
{
  // Same quantity as _computeEjkMean(Xj, cell_i_id, mean_of_ejk), but computed
  // on the mirror image of cell `cell_i_id` through the symmetry plane given by
  // (origin, normal).
  //
  // Each node is mirrored with symmetrize_coordinates. The local node order is
  // then permuted (reversed for the 1D/2D cells and for the base of the
  // pyramid/hexahedron; first two nodes of each triangle swapped for
  // tetrahedra/prisms). This presumably keeps the reference mapping correctly
  // oriented after the reflection.
  //
  // Throws NotImplementedError for cell types not handled in this dimension.
  const auto cell_nodes    = m_cell_to_node_matrix[cell_i_id];
  const CellType cell_kind = m_cell_type[cell_i_id];
  const double cell_volume = m_Vj[cell_i_id];

  // Mirror image of the `local_node`-th node of the cell.
  auto mirrored = [&](const size_t local_node) {
    return symmetrize_coordinates(origin, normal, m_xr[cell_nodes[local_node]]);
  };

  if constexpr (MeshType::Dimension == 1) {
    switch (cell_kind) {
    case CellType::Line: {
      const LineTransformation<1> T{mirrored(1), mirrored(0)};
      const auto& quadrature =
        QuadratureManager::instance().getLineFormula(GaussLegendreQuadratureDescriptor{m_polynomial_degree});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    default: {
      throw NotImplementedError("unexpected cell type: " + std::string{name(cell_kind)});
    }
    }
  } else if constexpr (MeshType::Dimension == 2) {
    switch (cell_kind) {
    case CellType::Triangle: {
      const TriangleTransformation<2> T{mirrored(2), mirrored(1), mirrored(0)};
      const auto& quadrature =
        QuadratureManager::instance().getTriangleFormula(GaussQuadratureDescriptor{m_polynomial_degree});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    case CellType::Quadrangle: {
      const SquareTransformation<2> T{mirrored(3), mirrored(2), mirrored(1), mirrored(0)};
      const auto& quadrature =
        QuadratureManager::instance().getSquareFormula(GaussLegendreQuadratureDescriptor{m_polynomial_degree + 1});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    default: {
      throw NotImplementedError("unexpected cell type: " + std::string{name(cell_kind)});
    }
    }
  } else {
    static_assert(MeshType::Dimension == 3);

    switch (cell_kind) {
    case CellType::Tetrahedron: {
      const TetrahedronTransformation T{mirrored(1), mirrored(0), mirrored(2), mirrored(3)};
      const auto& quadrature =
        QuadratureManager::instance().getTetrahedronFormula(GaussQuadratureDescriptor{m_polynomial_degree});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    case CellType::Prism: {
      const PrismTransformation T{mirrored(1), mirrored(0), mirrored(2),   //
                                  mirrored(4), mirrored(3), mirrored(5)};
      const auto& quadrature =
        QuadratureManager::instance().getPrismFormula(GaussQuadratureDescriptor{m_polynomial_degree + 1});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    case CellType::Pyramid: {
      const PyramidTransformation T{mirrored(3), mirrored(2), mirrored(1), mirrored(0), mirrored(4)};
      const auto& quadrature =
        QuadratureManager::instance().getPyramidFormula(GaussQuadratureDescriptor{m_polynomial_degree + 1});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    case CellType::Hexahedron: {
      const CubeTransformation T{mirrored(3), mirrored(2), mirrored(1), mirrored(0),
                                 mirrored(7), mirrored(6), mirrored(5), mirrored(4)};
      const auto& quadrature =
        QuadratureManager::instance().getCubeFormula(GaussLegendreQuadratureDescriptor{m_polynomial_degree + 1});

      this->_computeEjkMean(quadrature, T, Xj, cell_volume, mean_of_ejk);
      break;
    }
    default: {
      throw NotImplementedError("unexpected cell type: " + std::string{name(cell_kind)});
    }
    }
  }
}

template <MeshConcept MeshTypeT>
void
PolynomialReconstruction::ElementIntegralReconstructionMatrixBuilder<MeshTypeT>::build(
  const CellId cell_j_id,
  ShrinkMatrixView<SmallMatrix<double>>& A)
{
  // Fills the rows of A for the reconstruction around cell `cell_j_id`.
  //
  // Row r, column l receives mean_{cell_r}(e_{l+1}) - mean_{cell_j}(e_{l+1}),
  // where e_k are the monomials centred at the centre of cell j (constant
  // monomial excluded). Rows come first from the regular stencil of cell j,
  // then, for each symmetry boundary in order, from the mirror images of that
  // boundary's ghost stencil cells.
  const auto& stencil_cell_list = m_stencil_array[cell_j_id];

  const Rd& Xj = m_xj[cell_j_id];
  this->_computeEjkMean(Xj, cell_j_id, m_mean_j_of_ejk);

  const size_t nb_columns = m_basis_dimension - 1;
  size_t row              = 0;

  // Writes m_mean_i_of_ejk - m_mean_j_of_ejk into the current row, then moves to the next one.
  auto store_difference_row = [&]() {
    for (size_t column = 0; column < nb_columns; ++column) {
      A(row, column) = m_mean_i_of_ejk[column] - m_mean_j_of_ejk[column];
    }
    ++row;
  };

  for (size_t i_cell = 0; i_cell < stencil_cell_list.size(); ++i_cell) {
    const CellId cell_i_id = stencil_cell_list[i_cell];
    this->_computeEjkMean(Xj, cell_i_id, m_mean_i_of_ejk);
    store_difference_row();
  }

  for (size_t i_symmetry = 0; i_symmetry < m_stencil_array.symmetryBoundaryStencilArrayList().size(); ++i_symmetry) {
    auto& ghost_stencil  = m_stencil_array.symmetryBoundaryStencilArrayList()[i_symmetry].stencilArray();
    auto ghost_cell_list = ghost_stencil[cell_j_id];

    const Rd& origin = m_symmetry_origin_list[i_symmetry];
    const Rd& normal = m_symmetry_normal_list[i_symmetry];

    for (size_t i_ghost = 0; i_ghost < ghost_cell_list.size(); ++i_ghost) {
      const CellId cell_i_id = ghost_cell_list[i_ghost];
      this->_computeEjkMeanInSymmetricCell(origin, normal, Xj, cell_i_id, m_mean_i_of_ejk);
      store_difference_row();
    }
  }
}

template <MeshConcept MeshTypeT>
PolynomialReconstruction::ElementIntegralReconstructionMatrixBuilder<
  MeshTypeT>::ElementIntegralReconstructionMatrixBuilder(const MeshType& mesh,
                                                         const size_t polynomial_degree,
                                                         const SmallArray<const Rd>& symmetry_origin_list,
                                                         const SmallArray<const Rd>& symmetry_normal_list,
                                                         const CellToCellStencilArray& stencil_array)
  : m_basis_dimension{DiscreteFunctionDPk<MeshType::Dimension, double>::BasisViewType::dimensionFromDegree(
      polynomial_degree)},
    m_polynomial_degree{polynomial_degree},

    m_wq_detJ_ek{m_basis_dimension},
    m_mean_j_of_ejk{m_basis_dimension - 1},
    m_mean_i_of_ejk{m_basis_dimension - 1},

    m_cell_to_node_matrix{mesh.connectivity().cellToNodeMatrix()},
    m_stencil_array{stencil_array},
    m_symmetry_origin_list{symmetry_origin_list},
    m_symmetry_normal_list{symmetry_normal_list},
    m_cell_type{mesh.connectivity().cellType()},
    m_Vj{MeshDataManager::instance().getMeshData(mesh).Vj()},
    m_xj{MeshDataManager::instance().getMeshData(mesh).xj()},
    m_xr{mesh.xr()}
{
  // Precomputes the index tables that _computeEjkMean uses to build the
  // monomial basis by successive products (no table is needed in 1D).
  // All counters are size_t: the countdowns stop at 0, so no signed (POSIX
  // ssize_t) counter and no signed/unsigned mixing is required.
  if constexpr (MeshType::Dimension == 2) {
    // y_row_index[i_y] is the position of the first monomial carrying y^{i_y}.
    // Row i_y holds the x powers 0 .. degree - i_y, i.e. degree + 1 - i_y entries.
    SmallArray<size_t> y_row_index(m_polynomial_degree + 1);

    y_row_index[0] = 0;
    for (size_t i_y = 1; i_y <= m_polynomial_degree; ++i_y) {
      const size_t previous_row_size = m_polynomial_degree + 2 - i_y;
      y_row_index[i_y]               = y_row_index[i_y - 1] + previous_row_size;
    }

    m_y_row_index = y_row_index;

  } else if constexpr (MeshType::Dimension == 3) {
    // Monomials are grouped in rows of x powers, one row per (y, z) power pair.
    // The rows of layer z^{i_z} form a "triangle" of degree + 1 - i_z rows whose
    // sizes decrease from degree + 1 - i_z down to 1.
    //  - yz_row_index[r]       : position of the first monomial of row r
    //                            (one extra trailing entry closes the last row)
    //  - z_triangle_index[i_z] : index of the first row of layer z^{i_z}
    //  - yz_row_size[r]        : number of monomials in row r
    SmallArray<size_t> yz_row_index((m_polynomial_degree + 2) * (m_polynomial_degree + 1) / 2 + 1);
    SmallArray<size_t> z_triangle_index(m_polynomial_degree + 1);

    {
      size_t i_z  = 0;
      size_t i_yz = 0;

      yz_row_index[i_yz++] = 0;
      for (size_t n = m_polynomial_degree + 1; n > 0; --n) {
        // i_yz - 1 rows have been recorded so far: that is the first row of this layer.
        z_triangle_index[i_z++] = i_yz - 1;
        for (size_t row_size = n; row_size > 0; --row_size) {
          yz_row_index[i_yz] = yz_row_index[i_yz - 1] + row_size;
          ++i_yz;
        }
      }
    }

    SmallArray<size_t> yz_row_size{yz_row_index.size() - 1};
    for (size_t i = 0; i < yz_row_size.size(); ++i) {
      yz_row_size[i] = yz_row_index[i + 1] - yz_row_index[i];
    }

    m_yz_row_index     = yz_row_index;
    m_z_triangle_index = z_triangle_index;
    m_yz_row_size      = yz_row_size;
  }
}

// Explicit instantiations for the three supported mesh dimensions. For each
// Mesh<d>, the constructor and build() are instantiated; the private helpers
// they call are instantiated implicitly from here.

// ---- Mesh<1> ----
template PolynomialReconstruction::ElementIntegralReconstructionMatrixBuilder<
  Mesh<1>>::ElementIntegralReconstructionMatrixBuilder(const MeshType&,
                                                       const size_t,
                                                       const SmallArray<const Rd>&,
                                                       const SmallArray<const Rd>&,
                                                       const CellToCellStencilArray&);

template void PolynomialReconstruction::ElementIntegralReconstructionMatrixBuilder<Mesh<1>>::build(
  const CellId,
  ShrinkMatrixView<SmallMatrix<double>>&);

// ---- Mesh<2> ----
template PolynomialReconstruction::ElementIntegralReconstructionMatrixBuilder<
  Mesh<2>>::ElementIntegralReconstructionMatrixBuilder(const MeshType&,
                                                       const size_t,
                                                       const SmallArray<const Rd>&,
                                                       const SmallArray<const Rd>&,
                                                       const CellToCellStencilArray&);

template void PolynomialReconstruction::ElementIntegralReconstructionMatrixBuilder<Mesh<2>>::build(
  const CellId,
  ShrinkMatrixView<SmallMatrix<double>>&);

// ---- Mesh<3> ----
template PolynomialReconstruction::ElementIntegralReconstructionMatrixBuilder<
  Mesh<3>>::ElementIntegralReconstructionMatrixBuilder(const MeshType&,
                                                       const size_t,
                                                       const SmallArray<const Rd>&,
                                                       const SmallArray<const Rd>&,
                                                       const CellToCellStencilArray&);

template void PolynomialReconstruction::ElementIntegralReconstructionMatrixBuilder<Mesh<3>>::build(
  const CellId,
  ShrinkMatrixView<SmallMatrix<double>>&);
