#include <scheme/FluxingAdvectionSolver.hpp>

#include <language/utils/EvaluateAtPoints.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/IMesh.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/MeshFaceBoundary.hpp>
#include <mesh/SubItemValuePerItem.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>

#include <cmath>

template <size_t Dimension>
class FluxingAdvectionSolver
{
 private:
  using Rd = TinyVector<Dimension>;

  using MeshType     = Mesh<Connectivity<Dimension>>;
  using MeshDataType = MeshData<Dimension>;

  const std::shared_ptr<const MeshType> m_old_mesh;
  const std::shared_ptr<const MeshType> m_new_mesh;

 public:
  // CellValue<double>
  // compute_PFnp1(const DiscreteFunctionP0<Dimension, const double> F, const double& dt, const double& dx)
  // {
  //   CellValue<double> PFnp1{m_mesh.connectivity()};

  //   DiscreteFunctionP0<Dimension, double> deltaF  = compute_delta2Fn(F);
  //   DiscreteFunctionP0<Dimension, double> deltaF0 = compute_delta2Fn(m_Fn);

  //   for (CellId cell_id = 0; cell_id < m_mesh.numberOfCells(); ++cell_id) {
  //     PFnp1[cell_id] = m_Fn[cell_id][0] + m_Fn[cell_id][1] -
  //                      (0.5 * dt / dx) * m_lambda * (deltaF[cell_id][0] - deltaF[cell_id][1]) -
  //                      (0.5 * dt / dx) * m_lambda * (deltaF0[cell_id][0] - deltaF0[cell_id][1]);
  //   }
  //   return PFnp1;
  // }

  // DiscreteFunctionP0<Dimension, double>
  // apply(const double& dt, const double& eps)
  // {
  //   const DiscreteFunctionP0<Dimension, const double>& F0 = m_Fn;
  //   DiscreteFunctionP0<Dimension, double> Fnp1            = copy(F0);
  //   DiscreteFunctionP0<Dimension, double> deltaFn         = compute_delta2Fn(F0);

  //   for (size_t p = 0; p < 2; ++p) {
  //     CellId first_cell_id = 0;
  //     const double dx      = m_dx_table[first_cell_id];

  //     DiscreteFunctionP0<Dimension, double> deltaFnp1 = compute_delta2Fn(Fnp1);

  //     const CellValue<const double> PFnp1  = compute_PFnp1(Fnp1, dt, dx);
  //     const CellValue<const double> APFnp1 = getA(PFnp1);
  //     const CellArray<const double> MPFnp1 = compute_M(PFnp1, APFnp1);
  //     const CellValue<const double> PFn    = compute_PFn(F0);
  //     const CellValue<const double> APFn   = getA(PFn);
  //     const CellArray<const double> MPFn   = compute_M(PFn, APFn);

  //     for (CellId cell_id = 0; cell_id < m_mesh.numberOfCells(); ++cell_id) {
  //       Fnp1[cell_id][0] = 1. / (1 + 0.5 * dt / eps) *
  //                          ((0.5 * dt / eps) * MPFnp1[cell_id][0] + F0[cell_id][0] -
  //                           (0.5 * dt / dx) * m_lambda * (deltaFnp1[cell_id][0] + deltaFn[cell_id][0]) +
  //                           (0.5 * dt / eps) * (MPFn[cell_id][0] - F0[cell_id][0]));
  //       Fnp1[cell_id][1] = 1. / (1 + 0.5 * dt / eps) *
  //                          ((0.5 * dt / eps) * MPFnp1[cell_id][1] + F0[cell_id][1] +
  //                           (0.5 * dt / dx) * m_lambda * (deltaFnp1[cell_id][1] + deltaFn[cell_id][1]) +
  //                           (0.5 * dt / eps) * (MPFn[cell_id][1] - F0[cell_id][1]));
  //     }
  //   }

  //   return Fnp1;
  // }

  FaceValue<double> computeFluxVolume() const;

  FluxingAdvectionSolver(const std::shared_ptr<const MeshType> old_mesh, const std::shared_ptr<const MeshType> new_mesh)
    : m_old_mesh{old_mesh}, m_new_mesh{new_mesh}
  {}

  ~FluxingAdvectionSolver() = default;
};

/// 1D fluxing volumes are not available yet.
///
/// @throws NotImplementedError always
template <>
FaceValue<double>
FluxingAdvectionSolver<1>::computeFluxVolume() const
{
  throw NotImplementedError("fluxing volumes are not implemented in dimension 1");
}

/// 2D fluxing volumes: for each face, the signed area of the quadrilateral
/// swept by the face, built from its two nodes in the old mesh (x0, x1)
/// and the same two nodes in the new mesh (x2 for node 1, x3 for node 0).
///
/// @throws NormalError if the old and new meshes do not share the same connectivity
template <>
FaceValue<double>
FluxingAdvectionSolver<2>::computeFluxVolume() const
{
  // Nodes are matched between the two meshes by id, which only makes sense
  // when both meshes are built on the very same connectivity object.
  if (m_new_mesh->shared_connectivity() != m_old_mesh->shared_connectivity()) {
    throw NormalError("Old and new meshes must share the same connectivity");
  }

  const auto face_to_node_matrix = m_old_mesh->connectivity().faceToNodeMatrix();

  // Fetch the node coordinates once; the lambda then captures these arrays
  // directly instead of going through the mesh pointers for every face.
  const auto old_xr = m_old_mesh->xr();
  const auto new_xr = m_new_mesh->xr();

  FaceValue<double> flux_volume(m_new_mesh->connectivity());
  parallel_for(
    m_new_mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) {
      const auto face_to_node = face_to_node_matrix[face_id];
      const Rd x0             = old_xr[face_to_node[0]];
      const Rd x1             = old_xr[face_to_node[1]];
      const Rd x2             = new_xr[face_to_node[1]];
      const Rd x3             = new_xr[face_to_node[0]];
      // Signed area of quadrilateral (x0, x1, x2, x3) = 1/2 * cross(x2 - x0, x3 - x1),
      // i.e. half the determinant built from its two diagonals.
      TinyMatrix<2> M(x2[0] - x0[0], x3[0] - x1[0], x2[1] - x0[1], x3[1] - x1[1]);
      flux_volume[face_id] = 0.5 * det(M);
    });
  return flux_volume;
}

/// 3D fluxing volumes are not available yet.
///
/// @throws NotImplementedError always
template <>
FaceValue<double>
FluxingAdvectionSolver<3>::computeFluxVolume() const
{
  throw NotImplementedError("fluxing volumes are not implemented in dimension 3");
}

/// Computes how many sub-cycles the fluxing remap must use so that, in each
/// sub-cycle, the volume leaving any cell does not exceed that cell's volume.
///
/// For every cell, the outgoing (negative, once oriented with respect to the
/// cell) fluxing volumes are summed; the cell's required number of cycles is
/// ceil(|outgoing volume| / Vj). The maximum over all cells is returned.
///
/// @param old_mesh        mesh on which the cell volumes Vj are evaluated
/// @param fluxing_volumes signed fluxing volume of each face
/// @return the number of sub-cycles (0 when no cell loses any volume)
template <typename MeshType, typename DataType>
auto
calculateRemapCycles(const std::shared_ptr<const MeshType>& old_mesh, const FaceValue<DataType>& fluxing_volumes)
{
  constexpr size_t Dimension                               = MeshType::Dimension;
  const FaceValuePerCell<const bool> cell_face_is_reversed = old_mesh->connectivity().cellFaceIsReversed();
  const auto cell_to_face_matrix                           = old_mesh->connectivity().cellToFaceMatrix();

  const CellValue<double> total_negative_flux(old_mesh->connectivity());
  total_negative_flux.fill(0);
  parallel_for(
    old_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
      const auto& cell_to_face = cell_to_face_matrix[cell_id];
      for (size_t i_face = 0; i_face < cell_to_face.size(); ++i_face) {
        const FaceId face_id = cell_to_face[i_face];
        double flux          = fluxing_volumes[face_id];
        // Orient the face flux with respect to this cell (negative = leaving the cell)
        if (!cell_face_is_reversed(cell_id, i_face)) {
          flux = -flux;
        }
        if (flux < 0) {
          total_negative_flux[cell_id] += flux;
        }
      }
    });

  MeshData<Dimension>& mesh_data   = MeshDataManager::instance().getMeshData(*old_mesh);
  const CellValue<const double> Vj = mesh_data.Vj();
  const CellValue<size_t> ratio(old_mesh->connectivity());
  parallel_for(
    old_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
      // std::abs is required: an unqualified abs may resolve to the C
      // `int abs(int)` overload and truncate the (fractional) outgoing volume.
      ratio[cell_id] = static_cast<size_t>(std::ceil(std::abs(total_negative_flux[cell_id]) / Vj[cell_id]));
    });
  size_t number_of_cycle = max(ratio);
  std::cout << " number_of_cycle " << number_of_cycle << "\n";
  return number_of_cycle;
}

template <typename MeshType, typename DataType>
auto
remapUsingFluxing(const std::shared_ptr<const MeshType>& old_mesh,
                  const std::shared_ptr<const MeshType>& new_mesh,
                  const FaceValue<double>& fluxing_volumes,
                  const size_t num,
                  const DiscreteFunctionP0<MeshType::Dimension, const DataType>& old_q)
{
  constexpr size_t Dimension = MeshType::Dimension;
  //  const Connectivity<Dimension>& connectivity = new_mesh->connectivity();
  const FaceValuePerCell<const bool> cell_face_is_reversed = new_mesh->connectivity().cellFaceIsReversed();
  DiscreteFunctionP0<Dimension, DataType> new_q(new_mesh, copy(old_q.cellValues()));
  DiscreteFunctionP0<Dimension, DataType> previous_q(new_mesh, copy(old_q.cellValues()));
  const auto cell_to_face_matrix      = new_mesh->connectivity().cellToFaceMatrix();
  const auto face_to_cell_matrix      = new_mesh->connectivity().faceToCellMatrix();
  MeshData<Dimension>& old_mesh_data  = MeshDataManager::instance().getMeshData(*old_mesh);
  const CellValue<const double> oldVj = old_mesh_data.Vj();
  const CellValue<double> Vjstep(new_mesh->connectivity());
  parallel_for(
    new_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
      Vjstep[cell_id] = oldVj[cell_id];
      new_q[cell_id] *= oldVj[cell_id];
    });
  for (size_t jstep = 0; jstep < num; ++jstep) {
    // std::cout << " step " << jstep << "\n";
    parallel_for(
      new_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
        const auto& cell_to_face = cell_to_face_matrix[cell_id];
        for (size_t i_face = 0; i_face < cell_to_face.size(); ++i_face) {
          FaceId face_id = cell_to_face[i_face];
          double flux    = fluxing_volumes[face_id];
          if (!cell_face_is_reversed(cell_id, i_face)) {
            flux = -flux;
          }
          const auto& face_to_cell = face_to_cell_matrix[face_id];
          if (face_to_cell.size() == 1) {
            continue;
          }
          CellId other_cell_id = face_to_cell[0];
          if (other_cell_id == cell_id) {
            other_cell_id = face_to_cell[1];
          }
          DataType fluxed_q = previous_q[cell_id];
          if (flux > 0) {
            fluxed_q = previous_q[other_cell_id];
          }
          Vjstep[cell_id] += flux / num;
          fluxed_q *= flux / num;
          new_q[cell_id] += fluxed_q;
        }
        // std::cout << " old q " << old_q[cell_id] << " new q " << new_q[cell_id] << "\n";
      });
    parallel_for(
      new_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
        previous_q[cell_id] = 1 / Vjstep[cell_id] * new_q[cell_id];
        //     std::cout << " old q " << old_q[cell_id] << " new q " << previous_q[cell_id] << "\n";
        //     std::cout << " old vj " << oldVj[cell_id] << " new Vj " << Vjstep[cell_id] << "\n";
      });
  }

  MeshData<Dimension>& new_mesh_data  = MeshDataManager::instance().getMeshData(*new_mesh);
  const CellValue<const double> newVj = new_mesh_data.Vj();
  parallel_for(
    new_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { new_q[cell_id] = 1 / newVj[cell_id] * new_q[cell_id]; });
  // for (CellId cell_id = 0; cell_id < new_mesh->numberOfCells(); ++cell_id) {
  //   if (abs(newVj[cell_id] - Vjstep[cell_id]) > 1e-15) {
  //     std::cout << " cell " << cell_id << " newVj " << newVj[cell_id] << " Vjstep " << Vjstep[cell_id] << " diff "
  //               << abs(newVj[cell_id] - Vjstep[cell_id]) << "\n";
  //   }
  // }
  return new_q;
}

/// Overload for DiscreteFunctionP0Vector: not implemented yet.
///
/// @throws NotImplementedError always
///
/// The trailing return statement is never executed. It only fixes the
/// deduced return type, so that this overload can be used from the same
/// generic visitor as the DiscreteFunctionP0 overload.
template <typename MeshType, typename DataType>
auto
remapUsingFluxing([[maybe_unused]] const std::shared_ptr<const MeshType>& old_mesh,
                  const std::shared_ptr<const MeshType>& new_mesh,
                  [[maybe_unused]] const FaceValue<double>& fluxing_volumes,
                  [[maybe_unused]] const size_t num,
                  const DiscreteFunctionP0Vector<MeshType::Dimension, const DataType>& old_q)
{
  constexpr size_t Dimension = MeshType::Dimension;

  // Throw before doing any work: the previous version deep-copied every cell
  // array only to discard it immediately.
  throw NotImplementedError("DiscreteFunctionP0Vector");

  // Unreachable: only used to deduce the return type.
  return DiscreteFunctionP0Vector<Dimension, DataType>(new_mesh, copy(old_q.cellArrays()));
}

/// Remaps a list of P0 discrete functions from their common (old) mesh onto
/// new_mesh using face fluxing volumes.
///
/// The fluxing volumes and the number of sub-cycles depend only on the two
/// meshes, so they are computed once and shared by all variables.
///
/// @param new_mesh           target mesh; must have the same dimension as the
///                           functions' mesh and share its connectivity
/// @param remapped_variables P0 functions, all defined on the same mesh
/// @return the remapped functions, in the input order, defined on new_mesh
/// @throws NormalError if the functions do not share a mesh, are not P0, or
///         if the meshes' dimensions differ
/// @throws NotImplementedError in dimensions 1 and 3
std::vector<std::shared_ptr<const DiscreteFunctionVariant>>
FluxingAdvectionSolverHandler(const std::shared_ptr<const IMesh> new_mesh,
                              const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& remapped_variables)
{
  const std::shared_ptr<const IMesh> old_mesh = getCommonMesh(remapped_variables);
  if (not old_mesh) {
    throw NormalError("discrete functions are not defined on the same mesh");
  }

  if (not checkDiscretizationType(remapped_variables, DiscreteFunctionType::P0)) {
    throw NormalError("fluxing advection solver expects P0 functions");
  }

  // new_mesh is down-cast using the old mesh's dimension below; a mismatch
  // would silently produce a null mesh pointer.
  if (new_mesh->dimension() != old_mesh->dimension()) {
    throw NormalError("old and new meshes must have the same dimension");
  }

  switch (old_mesh->dimension()) {
  case 1: {
    // Returning the input unchanged would hand back functions still living on
    // the old mesh as if they had been remapped.
    throw NotImplementedError("Fluxing advection solver not implemented in dimension 1");
  }
  case 2: {
    constexpr size_t Dimension = 2;
    using MeshType             = Mesh<Connectivity<Dimension>>;

    const std::shared_ptr<const MeshType> old_mesh0 = std::dynamic_pointer_cast<const MeshType>(old_mesh);
    const std::shared_ptr<const MeshType> new_mesh0 = std::dynamic_pointer_cast<const MeshType>(new_mesh);

    FluxingAdvectionSolver<Dimension> solver(old_mesh0, new_mesh0);

    FaceValue<double> fluxing_volumes = solver.computeFluxVolume();
    const size_t number_of_cycles     = calculateRemapCycles(old_mesh0, fluxing_volumes);

    std::vector<std::shared_ptr<const DiscreteFunctionVariant>> new_variables;
    new_variables.reserve(remapped_variables.size());

    for (const auto& variable_v : remapped_variables) {
      DiscreteFunctionVariant new_variable = std::visit(
        [&](auto&& variable) -> DiscreteFunctionVariant {
          using DiscreteFunctionT = std::decay_t<decltype(variable)>;
          if constexpr (std::is_same_v<MeshType, typename DiscreteFunctionT::MeshType>) {
            return remapUsingFluxing(old_mesh0, new_mesh0, fluxing_volumes, number_of_cycles, variable);
          } else {
            throw UnexpectedError("incompatible mesh types");
          }
        },
        variable_v->discreteFunction());

      new_variables.push_back(std::make_shared<DiscreteFunctionVariant>(new_variable));
    }

    return new_variables;
  }
  case 3: {
    throw NotImplementedError("Fluxing advection solver not implemented in dimension 3");
  }
  default: {
    throw UnexpectedError("Invalid mesh dimension");
  }
  }
}

/// Remaps one P0 discrete function from its mesh onto new_mesh using face
/// fluxing volumes.
///
/// @param new_mesh          target mesh; must have the same dimension as the
///                          function's mesh and share its connectivity
/// @param remapped_variable P0 function to remap
/// @return the remapped function, defined on new_mesh
/// @throws NormalError if the function is not P0, if the meshes' dimensions
///         differ, or in dimension 1
/// @throws NotImplementedError in dimension 3
std::shared_ptr<const DiscreteFunctionVariant>
FluxingAdvectionSolverHandler(const std::shared_ptr<const IMesh> new_mesh,
                              const std::shared_ptr<const DiscreteFunctionVariant>& remapped_variable)
{
  const std::shared_ptr<const IMesh> old_mesh = getCommonMesh({remapped_variable});

  if (not checkDiscretizationType({remapped_variable}, DiscreteFunctionType::P0)) {
    throw NormalError("fluxing advection solver expects P0 functions");
  }

  // new_mesh is down-cast using the old mesh's dimension below; a mismatch
  // would silently produce a null mesh pointer that is dereferenced later.
  if (new_mesh->dimension() != old_mesh->dimension()) {
    throw NormalError("old and new meshes must have the same dimension");
  }

  switch (old_mesh->dimension()) {
  case 1: {
    throw NormalError("Not yet implemented in 1d");
  }
  case 2: {
    constexpr size_t Dimension = 2;
    using MeshType             = Mesh<Connectivity<Dimension>>;

    const std::shared_ptr<const MeshType> old_mesh0 = std::dynamic_pointer_cast<const MeshType>(old_mesh);
    const std::shared_ptr<const MeshType> new_mesh0 = std::dynamic_pointer_cast<const MeshType>(new_mesh);

    FluxingAdvectionSolver<Dimension> solver(old_mesh0, new_mesh0);

    FaceValue<double> fluxing_volumes = solver.computeFluxVolume();
    size_t number_of_cycles           = calculateRemapCycles(old_mesh0, fluxing_volumes);

    DiscreteFunctionVariant new_variable = std::visit(
      [&](auto&& variable) -> DiscreteFunctionVariant {
        using DiscreteFunctionT = std::decay_t<decltype(variable)>;
        if constexpr (std::is_same_v<MeshType, typename DiscreteFunctionT::MeshType>) {
          return remapUsingFluxing(old_mesh0, new_mesh0, fluxing_volumes, number_of_cycles, variable);
        } else {
          throw UnexpectedError("incompatible mesh types");
        }
      },
      remapped_variable->discreteFunction());

    return std::make_shared<DiscreteFunctionVariant>(new_variable);
  }
  case 3: {
    throw NotImplementedError("Fluxing advection solver not implemented in dimension 3");
  }
  default: {
    throw UnexpectedError("Invalid mesh dimension");
  }
  }
}