#include <scheme/PolynomialReconstruction.hpp>

#include <algebra/Givens.hpp>
#include <algebra/ShrinkMatrixView.hpp>
#include <algebra/SmallMatrix.hpp>
#include <geometry/SymmetryUtils.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/MeshFlatFaceBoundary.hpp>
#include <mesh/NamedBoundaryDescriptor.hpp>
#include <mesh/StencilDescriptor.hpp>
#include <mesh/StencilManager.hpp>
#include <scheme/DiscreteFunctionDPkVariant.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/DiscreteFunctionVariant.hpp>
#include <scheme/reconstruction_utils/BoundaryIntegralReconstructionMatrixBuilder.hpp>
#include <scheme/reconstruction_utils/CellCenterReconstructionMatrixBuilder.hpp>
#include <scheme/reconstruction_utils/ElementIntegralReconstructionMatrixBuilder.hpp>
#include <scheme/reconstruction_utils/MutableDiscreteFunctionDPkVariant.hpp>

template <MeshConcept MeshType>
class PolynomialReconstruction::Internal
{
 private:
  using Rd = TinyVector<MeshType::Dimension>;

  friend PolynomialReconstruction;

  template <typename MatrixType>
  static void
  buildB(const CellId& cell_j_id,
         const CellToCellStencilArray& stencil_array,
         const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_variant_list,
         SmallArray<const Rd> symmetry_normal_list,
         ShrinkMatrixView<MatrixType>& B)
  {
    auto stencil_cell_list = stencil_array[cell_j_id];

    size_t column_begin = 0;
    for (size_t i_discrete_function_variant = 0; i_discrete_function_variant < discrete_function_variant_list.size();
         ++i_discrete_function_variant) {
      const auto& discrete_function_variant = discrete_function_variant_list[i_discrete_function_variant];

      std::visit(
        [&](auto&& discrete_function) {
          using DiscreteFunctionT = std::decay_t<decltype(discrete_function)>;
          if constexpr (is_discrete_function_P0_v<DiscreteFunctionT>) {
            using DataType     = std::decay_t<typename DiscreteFunctionT::data_type>;
            const DataType& qj = discrete_function[cell_j_id];
            size_t index       = 0;
            for (size_t i = 0; i < stencil_cell_list.size(); ++i, ++index) {
              const CellId cell_i_id = stencil_cell_list[i];
              const DataType& qi_qj  = discrete_function[cell_i_id] - qj;
              if constexpr (std::is_arithmetic_v<DataType>) {
                B(index, column_begin) = qi_qj;
              } else if constexpr (is_tiny_vector_v<DataType>) {
                for (size_t kB = column_begin, k = 0; k < DataType::Dimension; ++k, ++kB) {
                  B(index, kB) = qi_qj[k];
                }
              } else if constexpr (is_tiny_matrix_v<DataType>) {
                for (size_t p = 0; p < DataType::NumberOfRows; ++p) {
                  const size_t kB = column_begin + p * DataType::NumberOfColumns;
                  for (size_t q = 0; q < DataType::NumberOfColumns; ++q) {
                    B(index, kB + q) = qi_qj(p, q);
                  }
                }
              }
            }

            for (size_t i_symmetry = 0; i_symmetry < stencil_array.symmetryBoundaryStencilArrayList().size();
                 ++i_symmetry) {
              auto& ghost_stencil  = stencil_array.symmetryBoundaryStencilArrayList()[i_symmetry].stencilArray();
              auto ghost_cell_list = ghost_stencil[cell_j_id];
              for (size_t i = 0; i < ghost_cell_list.size(); ++i, ++index) {
                const CellId cell_i_id = ghost_cell_list[i];
                if constexpr (std::is_arithmetic_v<DataType>) {
                  const DataType& qi_qj  = discrete_function[cell_i_id] - qj;
                  B(index, column_begin) = qi_qj;
                } else if constexpr (is_tiny_vector_v<DataType>) {
                  if constexpr (DataType::Dimension == MeshType::Dimension) {
                    const Rd& normal = symmetry_normal_list[i_symmetry];

                    const DataType& qi    = discrete_function[cell_i_id];
                    const DataType& qi_qj = symmetrize_vector(normal, qi) - qj;
                    for (size_t kB = column_begin, k = 0; k < DataType::Dimension; ++k, ++kB) {
                      B(index, kB) = qi_qj[k];
                    }
                  } else {
                    // LCOV_EXCL_START
                    std::stringstream error_msg;
                    error_msg << "cannot symmetrize vectors of dimension " << DataType::Dimension
                              << " using a mesh of dimension " << MeshType::Dimension;
                    throw UnexpectedError(error_msg.str());
                    // LCOV_EXCL_STOP
                  }
                } else if constexpr (is_tiny_matrix_v<DataType>) {
                  if constexpr ((DataType::NumberOfColumns == DataType::NumberOfRows) and
                                (DataType::NumberOfColumns == MeshType::Dimension)) {
                    const Rd& normal = symmetry_normal_list[i_symmetry];

                    const DataType& qi    = discrete_function[cell_i_id];
                    const DataType& qi_qj = symmetrize_matrix(normal, qi) - qj;
                    for (size_t p = 0; p < DataType::NumberOfRows; ++p) {
                      for (size_t q = 0; q < DataType::NumberOfColumns; ++q) {
                        B(index, column_begin + p * DataType::NumberOfColumns + q) = qi_qj(p, q);
                      }
                    }
                  } else {
                    // LCOV_EXCL_START
                    std::stringstream error_msg;
                    error_msg << "cannot symmetrize matrices of dimensions " << DataType::NumberOfRows << 'x'
                              << DataType::NumberOfColumns << " using a mesh of dimension " << MeshType::Dimension;
                    throw UnexpectedError(error_msg.str());
                    // LCOV_EXCL_STOP
                  }
                }
              }
            }

            if constexpr (std::is_arithmetic_v<DataType>) {
              ++column_begin;
            } else if constexpr (is_tiny_vector_v<DataType> or is_tiny_matrix_v<DataType>) {
              column_begin += DataType::Dimension;
            }
          } else if constexpr (is_discrete_function_P0_vector_v<DiscreteFunctionT>) {
            using DataType = std::decay_t<typename DiscreteFunctionT::data_type>;

            const auto qj_vector = discrete_function[cell_j_id];

            if constexpr (std::is_arithmetic_v<DataType>) {
              size_t index = 0;
              for (size_t i = 0; i < stencil_cell_list.size(); ++i, ++index) {
                const CellId cell_i_id = stencil_cell_list[i];
                for (size_t l = 0; l < qj_vector.size(); ++l) {
                  const DataType& qj         = qj_vector[l];
                  const DataType& qi_qj      = discrete_function[cell_i_id][l] - qj;
                  B(index, column_begin + l) = qi_qj;
                }
              }

              for (size_t i_symmetry = 0; i_symmetry < stencil_array.symmetryBoundaryStencilArrayList().size();
                   ++i_symmetry) {
                auto& ghost_stencil  = stencil_array.symmetryBoundaryStencilArrayList()[i_symmetry].stencilArray();
                auto ghost_cell_list = ghost_stencil[cell_j_id];
                for (size_t i = 0; i < ghost_cell_list.size(); ++i, ++index) {
                  const CellId cell_i_id = ghost_cell_list[i];
                  for (size_t l = 0; l < qj_vector.size(); ++l) {
                    const DataType& qj         = qj_vector[l];
                    const DataType& qi_qj      = discrete_function[cell_i_id][l] - qj;
                    B(index, column_begin + l) = qi_qj;
                  }
                }
              }
            } else if constexpr (is_tiny_vector_v<DataType>) {
              size_t index = 0;
              for (size_t i = 0; i < stencil_cell_list.size(); ++i, ++index) {
                const CellId cell_i_id = stencil_cell_list[i];
                for (size_t l = 0; l < qj_vector.size(); ++l) {
                  const DataType& qj    = qj_vector[l];
                  const DataType& qi_qj = discrete_function[cell_i_id][l] - qj;
                  for (size_t kB = column_begin + l * DataType::Dimension, k = 0; k < DataType::Dimension; ++k, ++kB) {
                    B(index, kB) = qi_qj[k];
                  }
                }
              }

              for (size_t i_symmetry = 0; i_symmetry < stencil_array.symmetryBoundaryStencilArrayList().size();
                   ++i_symmetry) {
                if constexpr (DataType::Dimension == MeshType::Dimension) {
                  auto& ghost_stencil  = stencil_array.symmetryBoundaryStencilArrayList()[i_symmetry].stencilArray();
                  auto ghost_cell_list = ghost_stencil[cell_j_id];

                  const Rd& normal = symmetry_normal_list[i_symmetry];

                  for (size_t i = 0; i < ghost_cell_list.size(); ++i, ++index) {
                    const CellId cell_i_id = ghost_cell_list[i];

                    for (size_t l = 0; l < qj_vector.size(); ++l) {
                      const DataType& qj    = qj_vector[l];
                      const DataType& qi    = discrete_function[cell_i_id][l];
                      const DataType& qi_qj = symmetrize_vector(normal, qi) - qj;
                      for (size_t kB = column_begin + l * DataType::Dimension, k = 0; k < DataType::Dimension;
                           ++k, ++kB) {
                        B(index, kB) = qi_qj[k];
                      }
                    }
                  }
                } else {
                  // LCOV_EXCL_START
                  std::stringstream error_msg;
                  error_msg << "cannot symmetrize vectors of dimension " << DataType::Dimension
                            << " using a mesh of dimension " << MeshType::Dimension;
                  throw UnexpectedError(error_msg.str());
                  // LCOV_EXCL_STOP
                }
              }
            } else if constexpr (is_tiny_matrix_v<DataType>) {
              size_t index = 0;
              for (size_t i = 0; i < stencil_cell_list.size(); ++i, ++index) {
                const CellId cell_i_id = stencil_cell_list[i];
                for (size_t l = 0; l < qj_vector.size(); ++l) {
                  const DataType& qj    = qj_vector[l];
                  const DataType& qi    = discrete_function[cell_i_id][l];
                  const DataType& qi_qj = qi - qj;

                  for (size_t p = 0; p < DataType::NumberOfRows; ++p) {
                    const size_t kB = column_begin + l * DataType::Dimension + p * DataType::NumberOfColumns;
                    for (size_t q = 0; q < DataType::NumberOfColumns; ++q) {
                      B(index, kB + q) = qi_qj(p, q);
                    }
                  }
                }
              }

              for (size_t i_symmetry = 0; i_symmetry < stencil_array.symmetryBoundaryStencilArrayList().size();
                   ++i_symmetry) {
                if constexpr ((DataType::NumberOfRows == MeshType::Dimension) and
                              (DataType::NumberOfColumns == MeshType::Dimension)) {
                  auto& ghost_stencil  = stencil_array.symmetryBoundaryStencilArrayList()[i_symmetry].stencilArray();
                  auto ghost_cell_list = ghost_stencil[cell_j_id];

                  const Rd& normal = symmetry_normal_list[i_symmetry];

                  for (size_t i = 0; i < ghost_cell_list.size(); ++i, ++index) {
                    const CellId cell_i_id = ghost_cell_list[i];

                    for (size_t l = 0; l < qj_vector.size(); ++l) {
                      const DataType& qj    = qj_vector[l];
                      const DataType& qi    = discrete_function[cell_i_id][l];
                      const DataType& qi_qj = symmetrize_matrix(normal, qi) - qj;

                      for (size_t p = 0; p < DataType::NumberOfRows; ++p) {
                        const size_t kB = column_begin + l * DataType::Dimension + p * DataType::NumberOfColumns;
                        for (size_t q = 0; q < DataType::NumberOfColumns; ++q) {
                          B(index, kB + q) = qi_qj(p, q);
                        }
                      }
                    }
                  }
                } else {
                  // LCOV_EXCL_START
                  std::stringstream error_msg;
                  error_msg << "cannot symmetrize vectors of dimension " << DataType::Dimension
                            << " using a mesh of dimension " << MeshType::Dimension;
                  throw UnexpectedError(error_msg.str());
                  // LCOV_EXCL_STOP
                }
              }
            }

            if constexpr (std::is_arithmetic_v<DataType>) {
              column_begin += qj_vector.size();
            } else if constexpr (is_tiny_vector_v<DataType> or is_tiny_matrix_v<DataType>) {
              column_begin += qj_vector.size() * DataType::Dimension;
            }

          } else {
            // LCOV_EXCL_START
            throw UnexpectedError("invalid discrete function type");
            // LCOV_EXCL_STOP
          }
        },
        discrete_function_variant->discreteFunction());
    }
  }

  template <typename MatrixType>
  static void
  rowWeighting(const CellId& cell_j_id,
               const CellToCellStencilArray& stencil_array,
               const CellValue<const Rd>& xj,
               const SmallArray<const Rd>& symmetry_origin_list,
               const SmallArray<const Rd>& symmetry_normal_list,
               ShrinkMatrixView<MatrixType>& A,
               ShrinkMatrixView<MatrixType>& B)
  {
    // Add row weighting (give more importance to cells that are
    // closer to j)
    auto stencil_cell_list = stencil_array[cell_j_id];

    const Rd& Xj = xj[cell_j_id];

    size_t index = 0;
    for (size_t i = 0; i < stencil_cell_list.size(); ++i, ++index) {
      const CellId cell_i_id = stencil_cell_list[i];
      const double weight    = 1. / l2Norm(xj[cell_i_id] - Xj);
      for (size_t l = 0; l < A.numberOfColumns(); ++l) {
        A(index, l) *= weight;
      }
      for (size_t l = 0; l < B.numberOfColumns(); ++l) {
        B(index, l) *= weight;
      }
    }
    for (size_t i_symmetry = 0; i_symmetry < stencil_array.symmetryBoundaryStencilArrayList().size(); ++i_symmetry) {
      auto& ghost_stencil  = stencil_array.symmetryBoundaryStencilArrayList()[i_symmetry].stencilArray();
      auto ghost_cell_list = ghost_stencil[cell_j_id];

      const Rd& origin = symmetry_origin_list[i_symmetry];
      const Rd& normal = symmetry_normal_list[i_symmetry];

      for (size_t i = 0; i < ghost_cell_list.size(); ++i, ++index) {
        const CellId cell_i_id = ghost_cell_list[i];
        const double weight    = 1. / l2Norm(symmetrize_coordinates(origin, normal, xj[cell_i_id]) - Xj);
        for (size_t l = 0; l < A.numberOfColumns(); ++l) {
          A(index, l) *= weight;
        }
        for (size_t l = 0; l < B.numberOfColumns(); ++l) {
          B(index, l) *= weight;
        }
      }
    }
  }

  // Solves the collection of least-squares systems A X = B with a
  // column scaling of A.
  //
  // G[l] receives the Euclidean norm of column l of A, column l of A is
  // multiplied by G[l], the scaled system is solved by
  // Givens::solveCollectionInPlace, and row l of the solution X is
  // finally multiplied by G[l] to undo the change of unknowns.
  //
  // A, X and G are modified even though they are passed as const
  // references.
  //
  // NOTE(review): columns are multiplied by their norm (not divided by
  // it). The scaling of X is consistent with that choice, so the
  // solution is mathematically unaffected, but confirm this is the
  // intended preconditioning direction.
  template <typename MatrixType>
  static void
  solveCollectionInPlaceWithPreconditionner(const ShrinkMatrixView<MatrixType>& A,
                                            const SmallMatrix<double>& X,
                                            const ShrinkMatrixView<MatrixType>& B,
                                            const SmallVector<double>& G)
  {
    const size_t nb_rows    = A.numberOfRows();
    const size_t nb_columns = A.numberOfColumns();

    // Columns are independent: compute the norm of a column and scale
    // it right away
    for (size_t column = 0; column < nb_columns; ++column) {
      double squared_norm = 0;
      for (size_t row = 0; row < nb_rows; ++row) {
        const double a_ij = A(row, column);
        squared_norm += a_ij * a_ij;
      }

      const double scale = std::sqrt(squared_norm);
      G[column]          = scale;

      for (size_t row = 0; row < nb_rows; ++row) {
        A(row, column) *= scale;
      }
    }

    Givens::solveCollectionInPlace(A, X, B);

    // Back to the original unknowns
    for (size_t row = 0; row < X.numberOfRows(); ++row) {
      const double scale = G[row];
      for (size_t column = 0; column < X.numberOfColumns(); ++column) {
        X(row, column) *= scale;
      }
    }
  }

  template <typename ReconstructionMatrixBuilderType>
  static void
  populateDiscreteFunctionDPkByCell(
    const CellId& cell_j_id,
    const size_t& degree,
    const SmallMatrix<double>& X,
    const ReconstructionMatrixBuilderType& reconstruction_matrix_builder,
    const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_variant_list,
    const std::vector<PolynomialReconstruction::MutableDiscreteFunctionDPkVariant>&
      mutable_discrete_function_dpk_variant_list)
  {
    size_t column_begin = 0;
    for (size_t i_dpk_variant = 0; i_dpk_variant < mutable_discrete_function_dpk_variant_list.size(); ++i_dpk_variant) {
      const auto& dpk_variant = mutable_discrete_function_dpk_variant_list[i_dpk_variant];

      const auto& discrete_function_variant = discrete_function_variant_list[i_dpk_variant];

      std::visit(
        [&](auto&& dpk_function, auto&& p0_function) {
          using DPkFunctionT = std::decay_t<decltype(dpk_function)>;
          using P0FunctionT  = std::decay_t<decltype(p0_function)>;
          using DataType     = std::remove_const_t<std::decay_t<typename DPkFunctionT::data_type>>;
          using P0DataType   = std::remove_const_t<std::decay_t<typename P0FunctionT::data_type>>;

          if constexpr (std::is_same_v<DataType, P0DataType>) {
            if constexpr (is_discrete_function_P0_v<P0FunctionT>) {
              if constexpr (is_discrete_function_dpk_scalar_v<DPkFunctionT>) {
                auto dpk_j = dpk_function.coefficients(cell_j_id);
                dpk_j[0]   = p0_function[cell_j_id];

                if constexpr (std::is_arithmetic_v<DataType>) {
                  if constexpr (ReconstructionMatrixBuilderType::handles_high_degrees) {
                    if (degree > 1) {
                      const auto& mean_j_of_ejk = reconstruction_matrix_builder.meanjOfEjk();
                      for (size_t i = 0; i < X.numberOfRows(); ++i) {
                        dpk_j[0] -= X(i, column_begin) * mean_j_of_ejk[i];
                      }
                    }
                  }

                  for (size_t i = 0; i < X.numberOfRows(); ++i) {
                    auto& dpk_j_ip1 = dpk_j[i + 1];
                    dpk_j_ip1       = X(i, column_begin);
                  }
                  ++column_begin;
                } else if constexpr (is_tiny_vector_v<DataType>) {
                  if constexpr (ReconstructionMatrixBuilderType::handles_high_degrees) {
                    if (degree > 1) {
                      const auto& mean_j_of_ejk = reconstruction_matrix_builder.meanjOfEjk();
                      for (size_t i = 0; i < X.numberOfRows(); ++i) {
                        auto& dpk_j_0 = dpk_j[0];
                        for (size_t k = 0; k < DataType::Dimension; ++k) {
                          dpk_j_0[k] -= X(i, column_begin + k) * mean_j_of_ejk[i];
                        }
                      }
                    }
                  }

                  for (size_t i = 0; i < X.numberOfRows(); ++i) {
                    auto& dpk_j_ip1 = dpk_j[i + 1];
                    for (size_t k = 0; k < DataType::Dimension; ++k) {
                      dpk_j_ip1[k] = X(i, column_begin + k);
                    }
                  }
                  column_begin += DataType::Dimension;
                } else if constexpr (is_tiny_matrix_v<DataType>) {
                  if constexpr (ReconstructionMatrixBuilderType::handles_high_degrees) {
                    if (degree > 1) {
                      const auto& mean_j_of_ejk = reconstruction_matrix_builder.meanjOfEjk();
                      for (size_t i = 0; i < X.numberOfRows(); ++i) {
                        auto& dpk_j_0 = dpk_j[0];
                        for (size_t k = 0; k < DataType::NumberOfRows; ++k) {
                          for (size_t l = 0; l < DataType::NumberOfColumns; ++l) {
                            dpk_j_0(k, l) -= X(i, column_begin + k * DataType::NumberOfColumns + l) * mean_j_of_ejk[i];
                          }
                        }
                      }
                    }
                  }

                  for (size_t i = 0; i < X.numberOfRows(); ++i) {
                    auto& dpk_j_ip1 = dpk_j[i + 1];
                    for (size_t k = 0; k < DataType::NumberOfRows; ++k) {
                      for (size_t l = 0; l < DataType::NumberOfColumns; ++l) {
                        dpk_j_ip1(k, l) = X(i, column_begin + k * DataType::NumberOfColumns + l);
                      }
                    }
                  }
                  column_begin += DataType::Dimension;
                } else {
                  // LCOV_EXCL_START
                  throw UnexpectedError("unexpected data type");
                  // LCOV_EXCL_STOP
                }
              } else {
                // LCOV_EXCL_START
                throw UnexpectedError("unexpected discrete dpk function type");
                // LCOV_EXCL_STOP
              }
            } else if constexpr (is_discrete_function_P0_vector_v<P0FunctionT>) {
              if constexpr (is_discrete_function_dpk_vector_v<DPkFunctionT>) {
                auto dpk_j        = dpk_function.coefficients(cell_j_id);
                auto cell_vector  = p0_function[cell_j_id];
                const size_t size = X.numberOfRows() + 1;

                for (size_t l = 0; l < cell_vector.size(); ++l) {
                  const size_t component_begin = l * size;
                  dpk_j[component_begin]       = cell_vector[l];
                  if constexpr (std::is_arithmetic_v<DataType>) {
                    if constexpr (ReconstructionMatrixBuilderType::handles_high_degrees) {
                      const auto& mean_j_of_ejk = reconstruction_matrix_builder.meanjOfEjk();
                      if (degree > 1) {
                        for (size_t i = 0; i < X.numberOfRows(); ++i) {
                          dpk_j[component_begin] -= X(i, column_begin) * mean_j_of_ejk[i];
                        }
                      }
                    }

                    for (size_t i = 0; i < X.numberOfRows(); ++i) {
                      auto& dpk_j_ip1 = dpk_j[component_begin + i + 1];
                      dpk_j_ip1       = X(i, column_begin);
                    }
                    ++column_begin;
                  } else if constexpr (is_tiny_vector_v<DataType>) {
                    if constexpr (ReconstructionMatrixBuilderType::handles_high_degrees) {
                      if (degree > 1) {
                        const auto& mean_j_of_ejk = reconstruction_matrix_builder.meanjOfEjk();
                        for (size_t i = 0; i < X.numberOfRows(); ++i) {
                          auto& dpk_j_0 = dpk_j[component_begin];
                          for (size_t k = 0; k < DataType::Dimension; ++k) {
                            dpk_j_0[k] -= X(i, column_begin + k) * mean_j_of_ejk[i];
                          }
                        }
                      }
                    }

                    for (size_t i = 0; i < X.numberOfRows(); ++i) {
                      auto& dpk_j_ip1 = dpk_j[component_begin + i + 1];
                      for (size_t k = 0; k < DataType::Dimension; ++k) {
                        dpk_j_ip1[k] = X(i, column_begin + k);
                      }
                    }
                    column_begin += DataType::Dimension;
                  } else if constexpr (is_tiny_matrix_v<DataType>) {
                    if constexpr (ReconstructionMatrixBuilderType::handles_high_degrees) {
                      if (degree > 1) {
                        const auto& mean_j_of_ejk = reconstruction_matrix_builder.meanjOfEjk();
                        for (size_t i = 0; i < X.numberOfRows(); ++i) {
                          auto& dpk_j_0 = dpk_j[component_begin];
                          for (size_t p = 0; p < DataType::NumberOfRows; ++p) {
                            for (size_t q = 0; q < DataType::NumberOfColumns; ++q) {
                              dpk_j_0(p, q) -=
                                X(i, column_begin + p * DataType::NumberOfColumns + q) * mean_j_of_ejk[i];
                            }
                          }
                        }
                      }
                    }

                    for (size_t i = 0; i < X.numberOfRows(); ++i) {
                      auto& dpk_j_ip1 = dpk_j[component_begin + i + 1];
                      for (size_t p = 0; p < DataType::NumberOfRows; ++p) {
                        for (size_t q = 0; q < DataType::NumberOfColumns; ++q) {
                          dpk_j_ip1(p, q) = X(i, column_begin + p * DataType::NumberOfColumns + q);
                        }
                      }
                    }
                    column_begin += DataType::Dimension;
                  } else {
                    // LCOV_EXCL_START
                    throw UnexpectedError("unexpected data type");
                    // LCOV_EXCL_STOP
                  }
                }
              } else {
                // LCOV_EXCL_START
                throw UnexpectedError("unexpected discrete dpk function type");
                // LCOV_EXCL_STOP
              }
            } else {
              // LCOV_EXCL_START
              throw UnexpectedError("unexpected discrete function type");
              // LCOV_EXCL_STOP
            }
          } else {
            // LCOV_EXCL_START
            throw UnexpectedError("incompatible data types");
            // LCOV_EXCL_STOP
          }
        },
        dpk_variant.mutableDiscreteFunctionDPk(), discrete_function_variant->discreteFunction());
    }
  }
};

// Returns the total number of scalar unknowns carried by the given
// discrete functions, i.e. the number of right-hand side columns of the
// reconstruction problem: 1 for a scalar P0 function, data_type::Dimension
// for a tiny vector or tiny matrix one, and that amount multiplied by
// size() for a P0 vector function.
//
// Throws UnexpectedError for unsupported function or data types.
size_t
PolynomialReconstruction::_getNumberOfColumns(
  const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_variant_list) const
{
  size_t number_of_columns = 0;
  // iterate by reference: copying the shared_ptr is useless here
  for (const auto& discrete_function_variant_p : discrete_function_variant_list) {
    number_of_columns += std::visit(
      [](auto&& discrete_function) -> size_t {
        using DiscreteFunctionT = std::decay_t<decltype(discrete_function)>;
        if constexpr (is_discrete_function_P0_v<DiscreteFunctionT>) {
          using data_type = std::decay_t<typename DiscreteFunctionT::data_type>;
          if constexpr (std::is_arithmetic_v<data_type>) {
            return 1;
          } else if constexpr (is_tiny_vector_v<data_type> or is_tiny_matrix_v<data_type>) {
            return data_type::Dimension;
          } else {
            // LCOV_EXCL_START
            throw UnexpectedError("unexpected data type " + demangle<data_type>());
            // LCOV_EXCL_STOP
          }
        } else if constexpr (is_discrete_function_P0_vector_v<DiscreteFunctionT>) {
          using data_type = std::decay_t<typename DiscreteFunctionT::data_type>;
          if constexpr (std::is_arithmetic_v<data_type>) {
            return discrete_function.size();
          } else if constexpr (is_tiny_vector_v<data_type> or is_tiny_matrix_v<data_type>) {
            return discrete_function.size() * data_type::Dimension;
          } else {
            // LCOV_EXCL_START
            throw UnexpectedError("unexpected data type " + demangle<data_type>());
            // LCOV_EXCL_STOP
          }
        } else {
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected discrete function type");
          // LCOV_EXCL_STOP
        }
      },
      discrete_function_variant_p->discreteFunction());
  }
  return number_of_columns;
}

// Creates, for each given P0 function, the DPk function that will hold
// its reconstruction on the mesh p_mesh, using the degree of the
// descriptor: a DiscreteFunctionDPk for a P0 function and a
// DiscreteFunctionDPkVector (with the same number of components) for a
// P0 vector function. The result is ordered as the input list.
//
// Throws UnexpectedError for any other kind of discrete function.
template <MeshConcept MeshType>
std::vector<PolynomialReconstruction::MutableDiscreteFunctionDPkVariant>
PolynomialReconstruction::_createMutableDiscreteFunctionDPKVariantList(
  const std::shared_ptr<const MeshType>& p_mesh,
  const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_variant_list) const
{
  std::vector<MutableDiscreteFunctionDPkVariant> mutable_discrete_function_dpk_variant_list;
  // exactly one DPk function is created per input function
  mutable_discrete_function_dpk_variant_list.reserve(discrete_function_variant_list.size());

  for (const auto& discrete_function_variant : discrete_function_variant_list) {
    std::visit(
      [&](auto&& discrete_function) {
        using DiscreteFunctionT = std::decay_t<decltype(discrete_function)>;
        if constexpr (is_discrete_function_P0_v<DiscreteFunctionT>) {
          using DataType = std::remove_const_t<std::decay_t<typename DiscreteFunctionT::data_type>>;
          mutable_discrete_function_dpk_variant_list.push_back(
            DiscreteFunctionDPk<MeshType::Dimension, DataType>(p_mesh, m_descriptor.degree()));
        } else if constexpr (is_discrete_function_P0_vector_v<DiscreteFunctionT>) {
          using DataType = std::remove_const_t<std::decay_t<typename DiscreteFunctionT::data_type>>;
          mutable_discrete_function_dpk_variant_list.push_back(
            DiscreteFunctionDPkVector<MeshType::Dimension, DataType>(p_mesh, m_descriptor.degree(),
                                                                     discrete_function.size()));
        } else {
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected discrete function type");
          // LCOV_EXCL_STOP
        }
      },
      discrete_function_variant->discreteFunction());
  }

  return mutable_discrete_function_dpk_variant_list;
}

// Checks that every given discrete function can be symmetrized on a
// mesh of type MeshType: tiny vector values must have the dimension of
// the mesh and tiny matrix values must be square matrices of that
// dimension. Other data types are accepted unchanged.
//
// Throws NormalError when dimensions do not match, and UnexpectedError
// if a function is neither a P0 nor a P0 vector function.
template <MeshConcept MeshType>
void
PolynomialReconstruction::_checkDataAndSymmetriesCompatibility(
  const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_variant_list) const
{
  const auto check_function = [](const auto& function) {
    using FunctionType = std::decay_t<decltype(function)>;

    constexpr bool is_p0_like =
      is_discrete_function_P0_v<FunctionType> or is_discrete_function_P0_vector_v<FunctionType>;

    if constexpr (not is_p0_like) {
      // LCOV_EXCL_START
      throw UnexpectedError("invalid discrete function type");
      // LCOV_EXCL_STOP
    } else {
      using ValueType = std::decay_t<typename FunctionType::data_type>;

      if constexpr (is_tiny_vector_v<ValueType>) {
        if constexpr (not(ValueType::Dimension == MeshType::Dimension)) {
          std::stringstream message;
          message << "cannot symmetrize vectors of dimension " << ValueType::Dimension
                  << " using a mesh of dimension " << MeshType::Dimension;
          throw NormalError(message.str());
        }
      } else if constexpr (is_tiny_matrix_v<ValueType>) {
        if constexpr (not((ValueType::NumberOfRows == MeshType::Dimension) and
                          (ValueType::NumberOfColumns == MeshType::Dimension))) {
          std::stringstream message;
          message << "cannot symmetrize matrices of dimensions " << ValueType::NumberOfRows << 'x'
                  << ValueType::NumberOfColumns << " using a mesh of dimension " << MeshType::Dimension;
          throw NormalError(message.str());
        }
      }
    }
  };

  for (const auto& p_function_variant : discrete_function_variant_list) {
    std::visit(check_function, p_function_variant->discreteFunction());
  }
}

// Reconstructs the polynomial (DPk) representation of every discrete function
// of the list on the given mesh, using ReconstructionMatrixBuilderType to
// assemble the per-cell reconstruction matrix.
//
// p_mesh: mesh shared by all the functions (MeshType must be the builder's
//         mesh type, see the static_assert).
// discrete_function_variant_list: functions to reconstruct.
//
// Returns one DiscreteFunctionDPkVariant per mutable DPk function created by
// _createMutableDiscreteFunctionDPKVariantList, after synchronization of its
// cell arrays.
//
// Only owned cells are treated in the parallel loop. For each of them a
// system A X = B is built (B from the function values, A from the matrix
// builder), optionally row-weighted, then solved either with the
// preconditioned solver or with Givens::solveCollectionInPlace, and X is used
// to populate the DPk coefficients of the cell.
//
// Errors raised by _checkDataAndSymmetriesCompatibility are propagated when
// symmetry boundaries are present.
template <typename ReconstructionMatrixBuilderType, MeshConcept MeshType>
[[nodiscard]] std::vector<std::shared_ptr<const DiscreteFunctionDPkVariant>>
PolynomialReconstruction::_build(
  const std::shared_ptr<const MeshType>& p_mesh,
  const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_variant_list) const
{
  static_assert(std::is_same_v<MeshType, typename ReconstructionMatrixBuilderType::MeshType>);

  const MeshType& mesh = *p_mesh;

  using Rd = TinyVector<MeshType::Dimension>;

  if (m_descriptor.symmetryBoundaryDescriptorList().size() > 0) {
    this->_checkDataAndSymmetriesCompatibility<MeshType>(discrete_function_variant_list);
  }

  const size_t number_of_columns = this->_getNumberOfColumns(discrete_function_variant_list);

  const size_t basis_dimension =
    DiscreteFunctionDPk<MeshType::Dimension, double>::BasisViewType::dimensionFromDegree(m_descriptor.degree());

  const auto& stencil_array =
    StencilManager::instance().getCellToCellStencilArray(mesh.connectivity(), m_descriptor.stencilDescriptor(),
                                                         m_descriptor.symmetryBoundaryDescriptorList());

  // NOTE(review): xr, Vj and cell_type are not referenced anywhere below in
  // this function. They are kept untouched in case the accessors are relied
  // upon for their side effects -- confirm before removing them.
  auto xr = mesh.xr();
  auto xj = MeshDataManager::instance().getMeshData(mesh).xj();
  auto Vj = MeshDataManager::instance().getMeshData(mesh).Vj();

  auto cell_is_owned = mesh.connectivity().cellIsOwned();
  auto cell_type     = mesh.connectivity().cellType();

  // Number of rows of the system of a cell: its regular stencil plus the
  // stencils attached to each symmetry boundary
  auto full_stencil_size = [&](const CellId cell_id) {
    auto stencil_cell_list = stencil_array[cell_id];
    size_t stencil_size    = stencil_cell_list.size();
    for (size_t i = 0; i < m_descriptor.symmetryBoundaryDescriptorList().size(); ++i) {
      auto& ghost_stencil = stencil_array.symmetryBoundaryStencilArrayList()[i].stencilArray();
      stencil_size += ghost_stencil[cell_id].size();
    }

    return stencil_size;
  };

  // Largest system size among owned cells: used to size the scratch pools.
  // Ownership is tested first so that no stencil size is computed for cells
  // that are skipped anyway.
  const size_t max_stencil_size = [&]() {
    size_t max_size = 0;
    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
      if (cell_is_owned[cell_id]) {
        const size_t stencil_size = full_stencil_size(cell_id);
        if (stencil_size > max_size) {
          max_size = stencil_size;
        }
      }
    }
    return max_size;
  }();

  // Outgoing normal and origin of each symmetry boundary, in the order of the
  // descriptor list. Each flat boundary is built once and provides both
  // quantities.
  const size_t number_of_symmetry_boundaries = m_descriptor.symmetryBoundaryDescriptorList().size();

  SmallArray<Rd> normal_list(number_of_symmetry_boundaries);
  SmallArray<Rd> origin_list(number_of_symmetry_boundaries);
  {
    size_t i_symmetry_boundary = 0;
    for (const auto& p_boundary_descriptor : m_descriptor.symmetryBoundaryDescriptorList()) {
      const IBoundaryDescriptor& boundary_descriptor = *p_boundary_descriptor;

      auto symmetry_boundary = getMeshFlatFaceBoundary(mesh, boundary_descriptor);

      normal_list[i_symmetry_boundary] = symmetry_boundary.outgoingNormal();
      origin_list[i_symmetry_boundary] = symmetry_boundary.origin();
      ++i_symmetry_boundary;
    }
  }

  SmallArray<const Rd> symmetry_normal_list = normal_list;
  SmallArray<const Rd> symmetry_origin_list = origin_list;

  Kokkos::Experimental::UniqueToken<Kokkos::DefaultExecutionSpace::execution_space,
                                    Kokkos::Experimental::UniqueTokenScope::Global>
    tokens;

  auto mutable_discrete_function_dpk_variant_list =
    this->_createMutableDiscreteFunctionDPKVariantList(p_mesh, discrete_function_variant_list);

  // Scratch storage: one entry per concurrency slot, indexed in the parallel
  // loop by the token acquired for the cell being treated
  SmallArray<SmallMatrix<double>> A_pool(Kokkos::DefaultExecutionSpace::concurrency());
  SmallArray<SmallMatrix<double>> B_pool(Kokkos::DefaultExecutionSpace::concurrency());
  SmallArray<SmallVector<double>> G_pool(Kokkos::DefaultExecutionSpace::concurrency());
  SmallArray<SmallMatrix<double>> X_pool(Kokkos::DefaultExecutionSpace::concurrency());

  for (size_t i = 0; i < A_pool.size(); ++i) {
    A_pool[i] = SmallMatrix<double>(max_stencil_size, basis_dimension - 1);
    B_pool[i] = SmallMatrix<double>(max_stencil_size, number_of_columns);
    G_pool[i] = SmallVector<double>(basis_dimension - 1);
    X_pool[i] = SmallMatrix<double>(basis_dimension - 1, number_of_columns);
  }

  // One matrix builder per concurrency slot as well
  SmallArray<std::shared_ptr<ReconstructionMatrixBuilderType>> reconstruction_matrix_builder_pool(A_pool.size());

  for (size_t t = 0; t < reconstruction_matrix_builder_pool.size(); ++t) {
    reconstruction_matrix_builder_pool[t] =
      std::make_shared<ReconstructionMatrixBuilderType>(*p_mesh, m_descriptor.degree(), symmetry_origin_list,
                                                        symmetry_normal_list, stencil_array);
  }

  parallel_for(
    mesh.numberOfCells(), PUGS_CLASS_LAMBDA(const CellId cell_j_id) {
      if (cell_is_owned[cell_j_id]) {
        const int32_t t = tokens.acquire();

        // A and B are views on the pooled matrices restricted to the number
        // of rows actually needed by this cell
        const size_t stencil_size = full_stencil_size(cell_j_id);

        ShrinkMatrixView A(A_pool[t], stencil_size);
        ShrinkMatrixView B(B_pool[t], stencil_size);

        Internal<MeshType>::buildB(cell_j_id, stencil_array, discrete_function_variant_list, symmetry_normal_list, B);

        ReconstructionMatrixBuilderType& reconstruction_matrix_builder = *reconstruction_matrix_builder_pool[t];
        reconstruction_matrix_builder.build(cell_j_id, A);

        if (m_descriptor.rowWeighting()) {
          Internal<MeshType>::rowWeighting(cell_j_id, stencil_array, xj, symmetry_origin_list, symmetry_normal_list, A,
                                           B);
        }

        const SmallMatrix<double>& X = X_pool[t];

        if (m_descriptor.preconditioning()) {
          // Add column  weighting preconditioning (increase the precision)
          SmallVector<double>& G = G_pool[t];

          Internal<MeshType>::solveCollectionInPlaceWithPreconditionner(A, X, B, G);
        } else {
          Givens::solveCollectionInPlace(A, X, B);
        }

        Internal<MeshType>::template populateDiscreteFunctionDPkByCell(cell_j_id, m_descriptor.degree(), X,
                                                                       reconstruction_matrix_builder,
                                                                       discrete_function_variant_list,
                                                                       mutable_discrete_function_dpk_variant_list);

        tokens.release(t);
      }
    });

  // Synchronize the cell arrays of each reconstructed function and expose it
  // through a DiscreteFunctionDPkVariant
  std::vector<std::shared_ptr<const DiscreteFunctionDPkVariant>> discrete_function_dpk_variant_list;
  discrete_function_dpk_variant_list.reserve(mutable_discrete_function_dpk_variant_list.size());

  for (auto&& discrete_function_dpk_variant_p : mutable_discrete_function_dpk_variant_list) {
    std::visit(
      [&](auto&& mutable_function_dpk) {
        synchronize(mutable_function_dpk.cellArrays());
        discrete_function_dpk_variant_list.push_back(
          std::make_shared<DiscreteFunctionDPkVariant>(mutable_function_dpk));
      },
      discrete_function_dpk_variant_p.mutableDiscreteFunctionDPk());
  }

  return discrete_function_dpk_variant_list;
}

// Dispatches the reconstruction to the matrix builder associated with the
// integration method stored in the descriptor:
//  - cell_center: CellCenterReconstructionMatrixBuilder,
//  - boundary:    BoundaryIntegralReconstructionMatrixBuilder for meshes of
//                 dimension 2, ElementIntegralReconstructionMatrixBuilder
//                 for the other dimensions,
//  - element:     ElementIntegralReconstructionMatrixBuilder.
//
// Throws UnexpectedError for any other integration method value.
template <MeshConcept MeshType>
[[nodiscard]] std::vector<std::shared_ptr<const DiscreteFunctionDPkVariant>>
PolynomialReconstruction::_build(
  const std::shared_ptr<const MeshType>& p_mesh,
  const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_variant_list) const
{
  const auto integration_method = m_descriptor.integrationMethodType();

  if (integration_method == IntegrationMethodType::cell_center) {
    return this->_build<CellCenterReconstructionMatrixBuilder<MeshType>>(p_mesh, discrete_function_variant_list);
  }

  // The boundary integral builder is only instantiated for 2d meshes
  if constexpr (MeshType::Dimension == 2) {
    if (integration_method == IntegrationMethodType::boundary) {
      return this->_build<BoundaryIntegralReconstructionMatrixBuilder<MeshType>>(p_mesh,
                                                                                 discrete_function_variant_list);
    }
  }

  // Reached by the element method, and by the boundary method when it has not
  // been handled above
  const bool use_element_integral = (integration_method == IntegrationMethodType::boundary) or
                                    (integration_method == IntegrationMethodType::element);
  if (use_element_integral) {
    return this->_build<ElementIntegralReconstructionMatrixBuilder<MeshType>>(p_mesh, discrete_function_variant_list);
  }

  // LCOV_EXCL_START
  throw UnexpectedError("invalid reconstruction matrix builder type");
  // LCOV_EXCL_STOP
}

// Public entry point: reconstructs all the given discrete functions at once.
//
// All functions must share the same mesh, otherwise a NormalError is thrown.
// The common mesh variant is then visited so that the reconstruction is
// performed by the _build overload matching the concrete mesh type.
std::vector<std::shared_ptr<const DiscreteFunctionDPkVariant>>
PolynomialReconstruction::build(
  const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_variant_list) const
{
  const bool share_same_mesh = hasSameMesh(discrete_function_variant_list);
  if (not share_same_mesh) {
    throw NormalError("cannot reconstruct functions living of different meshes simultaneously");
  }

  auto common_mesh_variant = getCommonMesh(discrete_function_variant_list);

  auto build_on_mesh = [&](auto&& p_mesh) { return this->_build(p_mesh, discrete_function_variant_list); };

  return std::visit(build_on_mesh, common_mesh_variant->variant());
}
