#include <scheme/RusanovEulerianCompositeSolverTools.hpp>

// Largest absolute acoustic eigenvalue in the direction of `normal`, scaled by the
// length of `normal`:
//
//   max(|u.n_hat - c|, |u.n_hat + c|) * |n|      with n_hat = n / |n|
//
// Because max(|a - c|, |a + c|) = |a| + |c| and (u.n_hat) * |n| = u.n, this equals
// |u.n| + |c| * |n|. Evaluating it in that form needs no normalisation of `normal`,
// so a zero-length `normal` yields 0 instead of the NaN that 1 / |n| would produce.
//
//   U_mean : velocity
//   c_mean : sound speed
//   normal : (non necessarily unit) geometric vector, e.g. Cjr / Cjf / Cje
template <class Rd>
double
toolsCompositeSolver::EvaluateMaxEigenValueTimesNormalLengthInGivenDirection(const Rd& U_mean,
                                                                             const double& c_mean,
                                                                             const Rd& normal)
{
  // (U_mean . unit_normal) * |normal|, computed without dividing by |normal|
  const double u_dot_normal = dot(U_mean, normal);

  return std::fabs(u_dot_normal) + std::fabs(c_mean) * l2Norm(normal);
}

// Returns the pair of acoustic eigenvalues in the direction of `normal`, each scaled by
// the length of `normal`:
//
//   { (u.n_hat - c) * |n| , (u.n_hat + c) * |n| }      with n_hat = n / |n|
//
// Since (u.n_hat) * |n| = u.n, this is computed as { u.n - c|n| , u.n + c|n| }, which
// avoids dividing by |n|: a zero-length `normal` yields {0, 0} instead of {NaN, NaN}.
// The first entry is the smaller one only when c_mean >= 0.
//
//   U_mean : velocity
//   c_mean : sound speed
//   normal : (non necessarily unit) geometric vector
template <class Rd>
std::pair<double, double>
toolsCompositeSolver::EvaluateMinMaxEigenValueTimesNormalLengthInGivenDirection(const Rd& U_mean,
                                                                                const double& c_mean,
                                                                                const Rd& normal)
{
  // (U_mean . unit_normal) * |normal|, computed without dividing by |normal|
  const double u_dot_normal = dot(U_mean, normal);
  const double c_times_norm = c_mean * l2Norm(normal);

  return {u_dot_normal - c_times_norm, u_dot_normal + c_times_norm};
}

// Computes a time step bound from the P0 velocity `u_v` and P0 sound speed `c_v`.
//
// For every cell j, the time step is
//
//   dt_j = V_j / sum_k lambda_max(u_j, c_j, C_jk)
//
// where the sum runs over the corner vectors Cjr (nodes), the face vectors Cjf and,
// in 3D only, the edge vectors Cje of the cell. lambda_max is
// EvaluateMaxEigenValueTimesNormalLengthInGivenDirection. The returned value is the
// minimum of dt_j over all cells. No CFL factor is applied here.
// NOTE(review): presumably the caller scales the result by a CFL number — confirm.
//
// The mesh is taken from `c_v`. `u_v` is assumed to live on the same mesh; this is not
// checked here.
// Throws NormalError for non-polygonal mesh types.
double
toolsCompositeSolver::compute_dt(const std::shared_ptr<const DiscreteFunctionVariant>& u_v,
                                 const std::shared_ptr<const DiscreteFunctionVariant>& c_v)
{
  const auto& c = c_v->get<DiscreteFunctionP0<const double>>();

  return std::visit(
    [&](auto&& p_mesh) -> double {
      const auto& mesh = *p_mesh;

      using MeshType                    = mesh_type_t<decltype(p_mesh)>;
      static constexpr size_t Dimension = MeshType::Dimension;
      using Rd                          = TinyVector<Dimension>;

      const auto& u = u_v->get<DiscreteFunctionP0<const Rd>>();

      if constexpr (is_polygonal_mesh_v<MeshType>) {
        // Single lookup of the mesh data. Only the (non unit) geometric vectors are
        // needed; the unit normals njr/njf/nje are deliberately not requested so
        // they are not computed just for this function.
        auto& mesh_data = MeshDataManager::instance().getMeshData(mesh);

        const auto Vj = mesh_data.Vj();

        const NodeValuePerCell<const Rd> Cjr = mesh_data.Cjr();
        const FaceValuePerCell<const Rd> Cjf = mesh_data.Cjf();
        const EdgeValuePerCell<const Rd> Cje = mesh_data.Cje();

        const auto& cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();
        const auto& cell_to_edge_matrix = mesh.connectivity().cellToEdgeMatrix();
        const auto& cell_to_face_matrix = mesh.connectivity().cellToFaceMatrix();

        CellValue<double> local_dt{mesh.connectivity()};

        parallel_for(
          mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
            const Rd u_j     = u[j];
            const double c_j = c[j];

            // Sum (not max) of the directional spectral radii over all geometric
            // vectors attached to the cell.
            double sum_of_radii = 0;

            const auto& cell_to_node = cell_to_node_matrix[j];
            for (size_t l = 0; l < cell_to_node.size(); ++l) {
              sum_of_radii += EvaluateMaxEigenValueTimesNormalLengthInGivenDirection(u_j, c_j, Cjr(j, l));
            }

            const auto& cell_to_face = cell_to_face_matrix[j];
            for (size_t l = 0; l < cell_to_face.size(); ++l) {
              sum_of_radii += EvaluateMaxEigenValueTimesNormalLengthInGivenDirection(u_j, c_j, Cjf(j, l));
            }

            // Edge contributions only exist as separate entities in 3D
            if constexpr (MeshType::Dimension == 3) {
              const auto& cell_to_edge = cell_to_edge_matrix[j];
              for (size_t l = 0; l < cell_to_edge.size(); ++l) {
                sum_of_radii += EvaluateMaxEigenValueTimesNormalLengthInGivenDirection(u_j, c_j, Cje(j, l));
              }
            }

            local_dt[j] = Vj[j] / sum_of_radii;
          });

        return min(local_dt);
      } else {
        throw NormalError("unexpected mesh type");
      }
    },
    c.meshVariant()->variant());
}

template double toolsCompositeSolver::EvaluateMaxEigenValueTimesNormalLengthInGivenDirection(
  const TinyVector<2>& U_mean,
  const double& c_mean,
  const TinyVector<2>& normal);

template double toolsCompositeSolver::EvaluateMaxEigenValueTimesNormalLengthInGivenDirection(
  const TinyVector<3>& U_mean,
  const double& c_mean,
  const TinyVector<3>& normal);