#include <scheme/GKS.hpp>

#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/MeshTraits.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>

template <MeshConcept MeshType>
class GKS
{
 public:
  std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>,
             std::shared_ptr<const DiscreteFunctionVariant>>
  solve(std::shared_ptr<const MeshType> p_mesh,
        std::shared_ptr<const DiscreteFunctionVariant> rho_v,
        std::shared_ptr<const DiscreteFunctionVariant> rho_U_v,
        std::shared_ptr<const DiscreteFunctionVariant> rho_E_v,
        std::shared_ptr<const DiscreteFunctionVariant> tau,
        const double delta,
        double dt)
  {
    using Rd = TinyVector<MeshType::Dimension>;

    const MeshType& mesh = *p_mesh;

    const double pi = std::acos(-1);

    DiscreteFunctionP0<const double> tau_n = tau->get<DiscreteFunctionP0<const double>>();
    CellValue<double> eta(mesh.connectivity());
    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
      if (tau_n[cell_id] == 0)
        eta[cell_id] = 0;
      else
        eta[cell_id] = tau_n[cell_id];   //(tau_n[cell_id] / dt) * (1 - std::exp(-dt / tau_n[cell_id]));
    }
    // std::cout << "eta = " << eta << std::endl;

    DiscreteFunctionP0<const double> rho_n   = rho_v->get<DiscreteFunctionP0<const double>>();
    DiscreteFunctionP0<const Rd> rho_U_n     = rho_U_v->get<DiscreteFunctionP0<const Rd>>();
    DiscreteFunctionP0<const double> rho_E_n = rho_E_v->get<DiscreteFunctionP0<const double>>();

    DiscreteFunctionP0<double> rho   = copy(rho_n);
    DiscreteFunctionP0<Rd> rho_U     = copy(rho_U_n);
    DiscreteFunctionP0<double> rho_E = copy(rho_E_n);

    auto& mesh_data = MeshDataManager::instance().getMeshData(mesh);
    auto Vj         = mesh_data.Vj();

    auto cell_to_node_matrix = mesh.connectivity().cellToNodeMatrix();

    NodeValue<double> rho_flux_Euler(mesh.connectivity());
    NodeValue<Rd> rho_U_flux_Euler(mesh.connectivity());
    NodeValue<double> rho_E_flux_Euler(mesh.connectivity());
    rho_flux_Euler.fill(0);
    rho_U_flux_Euler.fill(Rd(0));
    rho_E_flux_Euler.fill(0);

    NodeValue<double> rho_flux_Navier(mesh.connectivity());
    NodeValue<Rd> rho_U_flux_Navier(mesh.connectivity());
    NodeValue<double> rho_E_flux_Navier(mesh.connectivity());
    rho_flux_Navier.fill(0);
    rho_U_flux_Navier.fill(Rd(0));
    rho_E_flux_Navier.fill(0);

    NodeValue<double> rho_node(mesh.connectivity());
    NodeValue<Rd> rho_U_node(mesh.connectivity());
    NodeValue<double> rho_E_node(mesh.connectivity());
    rho_node.fill(0);
    rho_U_node.fill(Rd(0));
    rho_E_node.fill(0);
    CellValue<double> lambda{p_mesh->connectivity()};
    // lambda.fill(0);
    CellValue<Rd> U{p_mesh->connectivity()};
    // U.fill(0);

    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
      U[cell_id][0]   = rho_U_n[cell_id][0] / rho_n[cell_id];
      double rho_U_2  = rho_U_n[cell_id][0] * U[cell_id][0];
      lambda[cell_id] = 0.5 * (1. + delta) * rho_n[cell_id] / (2 * rho_E_n[cell_id] - rho_U_2);
    }

    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
      double U_2 = U[cell_id][0] * U[cell_id][0];

      double rho_cell_left = rho_n[cell_id] * (1 + std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0]));

      Rd rho_U_cell_left;
      rho_U_cell_left[0] = rho_U_n[cell_id][0] * (1. + std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0])) +
                           rho_n[cell_id] * std::exp(-lambda[cell_id] * U_2) / std::sqrt(pi * lambda[cell_id]);

      double rho_E_cell_left =
        rho_E_n[cell_id] * (1. + std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0])) +
        0.5 * rho_U_n[cell_id][0] * std::exp(-lambda[cell_id] * U_2) / std::sqrt(pi * lambda[cell_id]);

      auto node_list = cell_to_node_matrix[cell_id];

      rho_node[node_list[1]]      = 0.5 * rho_cell_left;
      rho_U_node[node_list[1]][0] = 0.5 * rho_U_cell_left[0];
      rho_E_node[node_list[1]]    = 0.5 * rho_E_cell_left;
    }

    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
      double U_2 = U[cell_id][0] * U[cell_id][0];

      double rho_cell_right = rho_n[cell_id] * (1. - std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0]));

      Rd rho_U_cell_right;
      rho_U_cell_right[0] = rho_U_n[cell_id][0] * (1. - std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0])) -
                            rho_n[cell_id] * std::exp(-lambda[cell_id] * U_2) / std::sqrt(pi * lambda[cell_id]);

      double rho_E_cell_right =
        rho_E_n[cell_id] * (1. - std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0])) -
        0.5 * rho_U_n[cell_id][0] * std::exp(-lambda[cell_id] * U_2) / std::sqrt(pi * lambda[cell_id]);

      auto node_list = cell_to_node_matrix[cell_id];

      rho_node[node_list[0]] += 0.5 * rho_cell_right;
      rho_U_node[node_list[0]][0] += 0.5 * rho_U_cell_right[0];
      rho_E_node[node_list[0]] += 0.5 * rho_E_cell_right;
    }

    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
      auto node_list      = cell_to_node_matrix[cell_id];
      double rho_U_2_node = rho_U_node[node_list[0]][0] * rho_U_node[node_list[0]][0] / rho_node[node_list[0]];

      rho_flux_Euler[node_list[0]] = rho_U_node[node_list[0]][0];
      rho_U_flux_Euler[node_list[0]][0] =
        delta * rho_U_2_node / (1. + delta) + 2 * rho_E_node[node_list[0]] / (1. + delta);
      rho_E_flux_Euler[node_list[0]] =
        rho_U_node[node_list[0]][0] / rho_node[node_list[0]] *
        ((3. + delta) * rho_E_node[node_list[0]] / (1. + delta) - rho_U_2_node / (1. + delta));
    }

    //##
    //%%%%%%%%%%%%%%%
    //##

    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
      double U_2     = U[cell_id][0] * U[cell_id][0];
      double rho_U_2 = rho_U_n[cell_id][0] * U[cell_id][0];

      Rd F2_fn_left;
      F2_fn_left[0] = (rho_U_2 + 0.5 * rho_n[cell_id] / lambda[cell_id]) *
                        (1. + std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0])) +
                      rho_U_n[cell_id][0] * std::exp(-lambda[cell_id] * U_2) / std::sqrt(pi * lambda[cell_id]);

      double F3_fn_left = 0.5 * rho_U_n[cell_id][0] * (U_2 + 0.5 * (delta + 3) / lambda[cell_id]) *
                            (1. + std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0])) +
                          0.5 * rho_n[cell_id] * (U_2 + 0.5 * (delta + 2) / lambda[cell_id]) *
                            std::exp(-lambda[cell_id] * U_2) / std::sqrt(pi * lambda[cell_id]);

      auto node_list = cell_to_node_matrix[cell_id];

      rho_U_flux_Navier[node_list[1]][0] = 0.5 * F2_fn_left[0];
      rho_E_flux_Navier[node_list[1]]    = 0.5 * F3_fn_left;
    }

    for (CellId cell_id = 1; cell_id < mesh.numberOfCells(); ++cell_id) {
      double U_2     = U[cell_id][0] * U[cell_id][0];
      double rho_U_2 = rho_U_n[cell_id][0] * U[cell_id][0];

      Rd F2_fn_right;
      F2_fn_right[0] = (rho_U_2 + 0.5 * rho_n[cell_id] / lambda[cell_id]) *
                         (1. - std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0])) -
                       rho_U_n[cell_id][0] * std::exp(-lambda[cell_id] * U_2) / std::sqrt(pi * lambda[cell_id]);

      double F3_fn_right = 0.5 * rho_U_n[cell_id][0] * (U_2 + 0.5 * (delta + 3) / lambda[cell_id]) *
                             (1. - std::erf(std::sqrt(lambda[cell_id]) * U[cell_id][0])) -
                           0.5 * rho_n[cell_id] * (U_2 + 0.5 * (delta + 2) / lambda[cell_id]) *
                             std::exp(-lambda[cell_id] * U_2) / std::sqrt(pi * lambda[cell_id]);
      auto node_list = cell_to_node_matrix[cell_id];

      rho_U_flux_Navier[node_list[0]][0] += 0.5 * F2_fn_right[0];
      rho_E_flux_Navier[node_list[0]] += 0.5 * F3_fn_right;
    }

    for (CellId cell_id = 1; cell_id < mesh.numberOfCells() - 1; ++cell_id) {
      auto node_list = cell_to_node_matrix[cell_id];

      const double rho_flux_Euler_sum   = (rho_flux_Euler[node_list[1]] - rho_flux_Euler[node_list[0]]);
      const Rd rho_U_flux_Euler_sum     = (rho_U_flux_Euler[node_list[1]] - rho_U_flux_Euler[node_list[0]]);
      const double rho_E_flux_Euler_sum = (rho_E_flux_Euler[node_list[1]] - rho_E_flux_Euler[node_list[0]]);

      const Rd rho_U_flux_Navier_sum     = (rho_U_flux_Navier[node_list[1]] - rho_U_flux_Navier[node_list[0]]);
      const double rho_E_flux_Navier_sum = (rho_E_flux_Navier[node_list[1]] - rho_E_flux_Navier[node_list[0]]);
      rho[cell_id] -= dt / Vj[cell_id] * (rho_flux_Euler_sum);
      rho_U[cell_id][0] -=
        dt / Vj[cell_id] *
        (rho_U_flux_Euler_sum[0] + eta[cell_id] * (rho_U_flux_Navier_sum[0] - rho_U_flux_Euler_sum[0]));
      rho_E[cell_id] -=
        dt / Vj[cell_id] * (rho_E_flux_Euler_sum + eta[cell_id] * (rho_E_flux_Navier_sum - rho_E_flux_Euler_sum));
    }
    return std::make_tuple(std::make_shared<DiscreteFunctionVariant>(rho),
                           std::make_shared<DiscreteFunctionVariant>(rho_U),
                           std::make_shared<DiscreteFunctionVariant>(rho_E));
  }

  GKS() = default;
};

/// @brief Entry point of the GKS solver: validates the inputs and dispatches on the mesh type.
/// @param rho_v   cell density (P0)
/// @param rho_U_v cell momentum rho U (P0)
/// @param rho_E_v cell total energy rho E (P0)
/// @param tau     cell-wise Euler/Navier blending weight (P0); it is indexed with the same cells as rho,
///                so it must be defined on the same mesh
/// @param delta   model parameter forwarded to GKS::solve
/// @param dt      time step
/// @return (rho, rho U, rho E) after one time step
/// @throws NormalError if the functions do not share a mesh, are not P0, or the mesh is not a 1D polygonal mesh
std::tuple<std::shared_ptr<const DiscreteFunctionVariant>,   // rho
           std::shared_ptr<const DiscreteFunctionVariant>,   // rho_U
           std::shared_ptr<const DiscreteFunctionVariant>>   // rho_E
gks(std::shared_ptr<const DiscreteFunctionVariant> rho_v,
    std::shared_ptr<const DiscreteFunctionVariant> rho_U_v,
    std::shared_ptr<const DiscreteFunctionVariant> rho_E_v,
    std::shared_ptr<const DiscreteFunctionVariant> tau,
    const double delta,
    const double dt)
{
  // tau is read cell by cell alongside rho, so it takes part in the same consistency checks.
  std::shared_ptr mesh_v = getCommonMesh({rho_v, rho_U_v, rho_E_v, tau});
  if (not mesh_v) {
    throw NormalError("discrete functions are not defined on the same mesh");
  }

  if (not checkDiscretizationType({rho_v, rho_U_v, rho_E_v, tau}, DiscreteFunctionType::P0)) {
    throw NormalError("GKS solver expects P0 functions");
  }

  return std::visit(
    [&](auto&& p_mesh)
      -> std::tuple<std::shared_ptr<const DiscreteFunctionVariant>, std::shared_ptr<const DiscreteFunctionVariant>,
                    std::shared_ptr<const DiscreteFunctionVariant>> {
      using MeshType = std::decay_t<decltype(*p_mesh)>;
      if constexpr (is_polygonal_mesh_v<MeshType>) {
        if constexpr (MeshType::Dimension == 1) {
          GKS<MeshType> solver;
          return solver.solve(p_mesh, rho_v, rho_U_v, rho_E_v, tau, delta, dt);

        } else {
          throw NormalError("dimension not treated");
        }

      } else {
        throw NormalError("unexpected mesh type");
      }
    },
    mesh_v->variant());
}
