#include <scheme/FluxingAdvectionSolver.hpp>

#include <language/utils/EvaluateAtPoints.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/IMesh.hpp>
#include <mesh/ItemArrayUtils.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/MeshFaceBoundary.hpp>
#include <mesh/SubItemValuePerItem.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>

// Remaps cell data from an old mesh onto a new mesh sharing the same
// connectivity. Data are exchanged between neighboring cells through their
// common faces, using a donor cell and a fluxing volume that are computed for
// each face from the positions of both meshes. The exchange may be split into
// several identical cycles.
template <size_t Dimension>
class FluxingAdvectionSolver
{
 private:
  using Rd = TinyVector<Dimension>;

  using MeshType     = Mesh<Connectivity<Dimension>>;
  using MeshDataType = MeshData<Dimension>;

  // Storage kinds that can hold a quantity while it is being remapped
  using RemapVariant = std::variant<CellValue<double>,   //
                                    CellValue<TinyVector<1>>,
                                    CellValue<TinyVector<2>>,
                                    CellValue<TinyVector<3>>,
                                    CellValue<TinyMatrix<1>>,
                                    CellValue<TinyMatrix<2>>,
                                    CellValue<TinyMatrix<3>>,
                                    CellArray<double>>;

  // Meshes before and after the displacement
  const std::shared_ptr<const MeshType> m_old_mesh;
  const std::shared_ptr<const MeshType> m_new_mesh;

  // Working copies of the quantities, stored in the order they were given
  std::vector<RemapVariant> m_remapped_list;

  // Geometrical data, set by _computeGeometricalData() at construction
  FaceValue<const CellId> m_donnor_cell;            // cell giving volume through each face
  FaceValue<const double> m_cycle_fluxing_volume;   // volume exchanged through each face during one cycle
  size_t m_number_of_cycles;                        // number of cycles of the remapping

  // Geometrical stage (in calling order, see _computeGeometricalData())
  void _computeGeometricalData();
  FaceValue<double> _computeAlgebraicFluxingVolume();
  void _computeDonorCells(FaceValue<const double> algebraic_fluxing_volumes);
  FaceValue<double> _computeFluxingVolume(FaceValue<double> algebraic_fluxing_volumes);
  void _computeCycleNumber(FaceValue<double> fluxing_volumes);

  // Appends a copy of the cell values of a P0 function to the remapped list
  template <typename DataType>
  void
  _storeValues(const DiscreteFunctionP0<Dimension, const DataType>& old_q)
  {
    m_remapped_list.emplace_back(copy(old_q.cellValues()));
  }

  // Appends a copy of the cell arrays of a P0 vector function to the remapped list
  template <typename DataType>
  void
  _storeValues(const DiscreteFunctionP0Vector<Dimension, const DataType>& old_q)
  {
    m_remapped_list.emplace_back(copy(old_q.cellArrays()));
  }

  // Remapping stage
  void _remapAllQuantities();

  template <typename CellDataType>
  void _remapOne(const CellValue<const double>& step_Vj, CellDataType& old_q);

 public:
  // Remaps the given quantities and returns them as discrete functions built on the new mesh
  std::vector<std::shared_ptr<const DiscreteFunctionVariant>>   //
  remap(const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& quantities);

  // Builds the solver and computes its geometrical data.
  //
  // Throws NormalError if one of the meshes cannot be cast to MeshType or if
  // the meshes do not share the same connectivity.
  FluxingAdvectionSolver(const std::shared_ptr<const IMesh> i_old_mesh, const std::shared_ptr<const IMesh> i_new_mesh)
    : m_old_mesh{std::dynamic_pointer_cast<const MeshType>(i_old_mesh)},
      m_new_mesh{std::dynamic_pointer_cast<const MeshType>(i_new_mesh)}
  {
    // a failed cast leaves an empty pointer
    const bool old_mesh_is_missing = (m_old_mesh.use_count() == 0);
    const bool new_mesh_is_missing = (m_new_mesh.use_count() == 0);

    if (old_mesh_is_missing or new_mesh_is_missing) {
      throw NormalError("old and new meshes must be of same type");
    }

    if (m_new_mesh->shared_connectivity() != m_old_mesh->shared_connectivity()) {
      throw NormalError("old and new meshes must share the same connectivity");
    }

    this->_computeGeometricalData();
  }

  ~FluxingAdvectionSolver() = default;
};

// Determines the donor cell of each face (general case).
//
// A face with a single neighboring cell takes that cell as donor. For any
// other face, the first neighboring cell is the donor when exactly one of
// these holds: the face is flagged as reversed in that cell, or its algebraic
// fluxing volume is non-positive. Otherwise the second neighboring cell is
// the donor.
template <size_t Dimension>
void
FluxingAdvectionSolver<Dimension>::_computeDonorCells(FaceValue<const double> algebraic_fluxing_volumes)
{
  const auto& connectivity = m_new_mesh->connectivity();

  const FaceValuePerCell<const bool> is_reversed_in_cell = connectivity.cellFaceIsReversed();
  const auto face_cells                                  = connectivity.faceToCellMatrix();
  const auto local_index_in_cells                        = connectivity.faceLocalNumbersInTheirCells();

  FaceValue<CellId> donor_of_face{m_old_mesh->connectivity()};

  parallel_for(
    m_new_mesh->numberOfFaces(), PUGS_LAMBDA(const FaceId face_id) {
      const auto& neighbors   = face_cells[face_id];
      const CellId first_cell = neighbors[0];

      if (neighbors.size() == 1) {
        donor_of_face[face_id] = first_cell;
        return;
      }

      const size_t i_face_in_first_cell = local_index_in_cells[face_id][0];
      const bool is_reversed            = is_reversed_in_cell[first_cell][i_face_in_first_cell];
      const bool is_non_positive        = (algebraic_fluxing_volumes[face_id] <= 0);

      // the two flags differ <=> exactly one of them is set
      if (is_reversed != is_non_positive) {
        donor_of_face[face_id] = first_cell;
      } else {
        donor_of_face[face_id] = neighbors[1];
      }
    });

  m_donnor_cell = donor_of_face;
}

// Determines the donor cell of each face in dimension 1.
//
// A face with a single neighboring cell takes that cell as donor. For any
// other face, the first neighboring cell is the donor when exactly one of
// these holds: the algebraic fluxing volume of the face is non-positive, or
// the face is the first face of that cell. Otherwise the second neighboring
// cell is the donor.
template <>
void
FluxingAdvectionSolver<1>::_computeDonorCells(FaceValue<const double> algebraic_fluxing_volumes)
{
  const auto& connectivity = m_new_mesh->connectivity();

  const auto face_cells = connectivity.faceToCellMatrix();
  const auto cell_faces = connectivity.cellToFaceMatrix();

  FaceValue<CellId> donor_of_face{m_old_mesh->connectivity()};

  parallel_for(
    m_new_mesh->numberOfFaces(), PUGS_LAMBDA(const FaceId face_id) {
      const auto& neighbors   = face_cells[face_id];
      const CellId first_cell = neighbors[0];

      if (neighbors.size() == 1) {
        donor_of_face[face_id] = first_cell;
        return;
      }

      const bool is_non_positive       = (algebraic_fluxing_volumes[face_id] <= 0);
      const bool is_first_face_of_cell = (cell_faces[first_cell][0] == face_id);

      // the two flags differ <=> exactly one of them is set
      if (is_non_positive != is_first_face_of_cell) {
        donor_of_face[face_id] = first_cell;
      } else {
        donor_of_face[face_id] = neighbors[1];
      }
    });

  m_donnor_cell = donor_of_face;
}

// Algebraic fluxing volumes in dimension 1: the signed displacement (new
// position minus old position) of each node, returned as face values.
template <>
FaceValue<double>
FluxingAdvectionSolver<1>::_computeAlgebraicFluxingVolume()
{
  const auto& connectivity = m_new_mesh->connectivity();

  auto old_xr = m_old_mesh->xr();
  auto new_xr = m_new_mesh->xr();

  // The node values and the face values are both built on this same array:
  // what is written node by node is then returned face by face (presumably
  // the array storage is shared by both -- TODO confirm).
  Array<double> displacements{m_new_mesh->numberOfNodes()};

  NodeValue<double> node_displacement(connectivity, displacements);
  parallel_for(
    m_new_mesh->numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
      const double old_x = old_xr[node_id][0];
      const double new_x = new_xr[node_id][0];

      node_displacement[node_id] = new_x - old_x;
    });

  FaceValue<double> face_fluxing_volume(connectivity, displacements);
  synchronize(face_fluxing_volume);

  return face_fluxing_volume;
}

// Algebraic fluxing volumes in dimension 2.
//
// For each face, the value is half the determinant of the matrix whose
// columns are the two diagonals of the quadrangle made of the old positions
// of the two face nodes followed by their new positions in reverse order.
template <>
FaceValue<double>
FluxingAdvectionSolver<2>::_computeAlgebraicFluxingVolume()
{
  const auto face_nodes = m_old_mesh->connectivity().faceToNodeMatrix();

  auto old_xr = m_old_mesh->xr();
  auto new_xr = m_new_mesh->xr();

  FaceValue<double> face_fluxing_volume(m_new_mesh->connectivity());

  parallel_for(
    m_new_mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) {
      const auto& nodes = face_nodes[face_id];

      // quadrangle (old_a, old_b, new_b, new_a)
      const Rd& old_a = old_xr[nodes[0]];
      const Rd& old_b = old_xr[nodes[1]];
      const Rd& new_b = new_xr[nodes[1]];
      const Rd& new_a = new_xr[nodes[0]];

      // first column: new_a - old_b, second column: new_b - old_a
      TinyMatrix<2> diagonals(new_a[0] - old_b[0], new_b[0] - old_a[0],   //
                              new_a[1] - old_b[1], new_b[1] - old_a[1]);

      face_fluxing_volume[face_id] = 0.5 * det(diagonals);
    });

  synchronize(face_fluxing_volume);

  return face_fluxing_volume;
}

// Algebraic fluxing volumes in dimension 3.
//
// For each face, the value is computed from the old and new positions of the
// face nodes. Quadrangular faces use their four nodes. Triangular faces use
// the same formula with the last node repeated (degenerate quadrangle).
//
// The result is synchronized before being returned, as in the 1D and 2D
// specializations.
//
// Throws NotImplementedError for faces that have neither 3 nor 4 nodes.
template <>
FaceValue<double>
FluxingAdvectionSolver<3>::_computeAlgebraicFluxingVolume()
{
  const auto face_to_node_matrix = m_old_mesh->connectivity().faceToNodeMatrix();
  FaceValue<double> algebraic_fluxing_volume(m_new_mesh->connectivity());
  auto old_xr = m_old_mesh->xr();
  auto new_xr = m_new_mesh->xr();
  parallel_for(
    m_new_mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) {
      // Value associated with the old positions (x0, x1, x2, x3) and the new
      // positions (x4, x5, x6, x7) of the four nodes of a face
      const auto swept_volume = [](const Rd& x0, const Rd& x1, const Rd& x2, const Rd& x3,   //
                                   const Rd& x4, const Rd& x5, const Rd& x6, const Rd& x7) {
        const Rd a1 = x6 - x1;
        const Rd a2 = x6 - x3;
        const Rd a3 = x6 - x4;

        const Rd b1 = x7 - x0;
        const Rd b2 = x5 - x0;
        const Rd b3 = x2 - x0;

        TinyMatrix<3> M1(a1 + b1, a2, a3);
        TinyMatrix<3> M2(b1, a2 + b2, a3);
        TinyMatrix<3> M3(a1, b2, a3 + b3);

        return (det(M1) + det(M2) + det(M3)) / 12;
      };

      const auto& face_to_node = face_to_node_matrix[face_id];

      switch (face_to_node.size()) {
      case 4: {
        algebraic_fluxing_volume[face_id] =
          swept_volume(old_xr[face_to_node[0]], old_xr[face_to_node[1]],   //
                       old_xr[face_to_node[2]], old_xr[face_to_node[3]],   //
                       new_xr[face_to_node[0]], new_xr[face_to_node[1]],   //
                       new_xr[face_to_node[2]], new_xr[face_to_node[3]]);
        break;
      }
      case 3: {
        // triangle: the third node plays the role of both the third and the
        // fourth nodes of the quadrangle
        algebraic_fluxing_volume[face_id] =
          swept_volume(old_xr[face_to_node[0]], old_xr[face_to_node[1]],   //
                       old_xr[face_to_node[2]], old_xr[face_to_node[2]],   //
                       new_xr[face_to_node[0]], new_xr[face_to_node[1]],   //
                       new_xr[face_to_node[2]], new_xr[face_to_node[2]]);
        break;
      }
      default: {
        throw NotImplementedError("Not implemented for faces that are neither triangles nor quadrangles");
      }
      }
    });

  synchronize(algebraic_fluxing_volume);
  return algebraic_fluxing_volume;
}

// Replaces each algebraic (signed) fluxing volume by its absolute value.
//
// Donor cells must have been built before this is called. The values are
// overwritten in the FaceValue received as argument, which is then returned.
template <size_t Dimension>
FaceValue<double>
FluxingAdvectionSolver<Dimension>::_computeFluxingVolume(FaceValue<double> algebraic_fluxing_volumes)
{
  Assert(m_donnor_cell.isBuilt());

  // Donor cells being defined, the sign of the volumes is no longer needed
  const auto number_of_faces = algebraic_fluxing_volumes.numberOfItems();

  parallel_for(
    number_of_faces, PUGS_LAMBDA(const FaceId face_id) {
      const double signed_volume = algebraic_fluxing_volumes[face_id];

      algebraic_fluxing_volumes[face_id] = std::abs(signed_volume);
    });

  return algebraic_fluxing_volumes;
}

// Computes the number of cycles of the remapping and the fluxing volume used
// for each face during one cycle.
//
// For each cell, the volumes of the faces it is the donor of are added up and
// divided by the volume of the cell in the old mesh. The number of cycles is
// the maximum over the cells of this ratio rounded up. When it is greater
// than one, the fluxing volumes are divided by it.
template <size_t Dimension>
void
FluxingAdvectionSolver<Dimension>::_computeCycleNumber(FaceValue<double> fluxing_volumes)
{
  const auto cell_faces = m_old_mesh->connectivity().cellToFaceMatrix();

  MeshData<Dimension>& old_mesh_data       = MeshDataManager::instance().getMeshData(*m_old_mesh);
  const CellValue<const double> old_volume = old_mesh_data.Vj();

  CellValue<size_t> cycles_per_cell(m_old_mesh->connectivity());

  parallel_for(
    m_old_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
      const auto& faces = cell_faces[cell_id];

      // total volume given by the cell
      double given_volume = 0;
      for (size_t i_face = 0; i_face < faces.size(); ++i_face) {
        const FaceId face_id = faces[i_face];
        if (cell_id == m_donnor_cell[face_id]) {
          given_volume += fluxing_volumes[face_id];
        }
      }

      cycles_per_cell[cell_id] = std::ceil(given_volume / old_volume[cell_id]);
    });
  synchronize(cycles_per_cell);

  const size_t cycle_count = max(cycles_per_cell);

  if (cycle_count > 1) {
    const double cycle_fraction = 1. / cycle_count;

    parallel_for(
      fluxing_volumes.numberOfItems(), PUGS_LAMBDA(const FaceId face_id) {
        fluxing_volumes[face_id] = fluxing_volumes[face_id] * cycle_fraction;
      });
  }

  m_cycle_fluxing_volume = fluxing_volumes;
  m_number_of_cycles     = cycle_count;
}

// Computes all the geometrical data used by the remapping: donor cells,
// fluxing volume of each face for one cycle, and number of cycles.
template <size_t Dimension>
void
FluxingAdvectionSolver<Dimension>::_computeGeometricalData()
{
  FaceValue<double> face_volumes = this->_computeAlgebraicFluxingVolume();

  // Donor cells come first: they are deduced from the signed values, which
  // _computeFluxingVolume then replaces by their absolute values
  this->_computeDonorCells(face_volumes);

  this->_computeCycleNumber(this->_computeFluxingVolume(face_volumes));
}

// Performs one cycle of fluxing for a single quantity.
//
// For each cell, the value is multiplied by the cell volume at the beginning
// of the cycle (step_Vj). Then, for each face of the cell that has more than
// one neighboring cell, the donor cell value times the cycle fluxing volume
// of the face is subtracted if the cell is the donor, or added otherwise.
//
// On exit old_q refers to the synchronized result, which has not been divided
// by any volume.
template <size_t Dimension>
template <typename CellDataType>
void
FluxingAdvectionSolver<Dimension>::_remapOne(const CellValue<const double>& step_Vj, CellDataType& old_q)
{
  static_assert(is_item_value_v<CellDataType> or is_item_array_v<CellDataType>, "invalid data type");

  const auto& connectivity = m_new_mesh->connectivity();

  const auto cell_faces = connectivity.cellToFaceMatrix();
  const auto face_cells = connectivity.faceToCellMatrix();

  // old_q is only read below, each cell writes its own entry of the copy
  auto new_q = copy(old_q);

  parallel_for(
    m_new_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
      const double cell_volume = step_Vj[cell_id];

      if constexpr (is_item_value_v<CellDataType>) {
        new_q[cell_id] *= cell_volume;
      } else if constexpr (is_item_array_v<CellDataType>) {
        auto cell_array = new_q[cell_id];
        for (size_t i = 0; i < cell_array.size(); ++i) {
          cell_array[i] *= cell_volume;
        }
      }

      const auto& faces = cell_faces[cell_id];
      for (size_t i_face = 0; i_face < faces.size(); ++i_face) {
        const FaceId face_id = faces[i_face];

        // nothing goes through a face that has a single neighboring cell
        if (face_cells[face_id].size() == 1) {
          continue;
        }

        CellId donor_id = m_donnor_cell[face_id];

        // negative when the cell gives, positive when it receives
        const double signed_volume = ((donor_id == cell_id) ? -1 : 1) * m_cycle_fluxing_volume[face_id];

        if constexpr (is_item_value_v<CellDataType>) {
          auto contribution = old_q[donor_id];
          contribution *= signed_volume;

          new_q[cell_id] += contribution;
        } else if constexpr (is_item_array_v<CellDataType>) {
          auto donor_array = old_q[donor_id];
          auto cell_array  = new_q[cell_id];
          for (size_t i = 0; i < cell_array.size(); ++i) {
            cell_array[i] += signed_volume * donor_array[i];
          }
        }
      }
    });

  synchronize(new_q);
  old_q = new_q;
}

// Remaps all stored quantities, performing m_number_of_cycles cycles.
//
// Each cycle:
//  - applies _remapOne to every stored quantity, using the cell volumes at
//    the beginning of the cycle,
//  - updates the cell volumes with the cycle fluxing volumes of their faces
//    (subtracted when the cell is the donor of the face, added otherwise),
//  - divides every stored quantity by the updated cell volumes.
//
// The cell volumes start from those of the old mesh.
template <size_t Dimension>
void
FluxingAdvectionSolver<Dimension>::_remapAllQuantities()
{
  const auto cell_to_face_matrix = m_new_mesh->connectivity().cellToFaceMatrix();

  MeshData<Dimension>& old_mesh_data = MeshDataManager::instance().getMeshData(*m_old_mesh);

  const CellValue<const double> old_Vj = old_mesh_data.Vj();
  const CellValue<double> step_Vj      = copy(old_Vj);

  // Allocated once for all cycles: it is entirely overwritten at each cycle
  CellValue<double> inv_Vj(m_old_mesh->connectivity());

  for (size_t jstep = 0; jstep < m_number_of_cycles; ++jstep) {
    for (auto& remapped_q : m_remapped_list) {
      std::visit([&](auto&& old_q) { this->_remapOne(step_Vj, old_q); }, remapped_q);
    }

    // NOTE(review): faces with a single neighboring cell are not skipped here
    // whereas _remapOne ignores them. Both are consistent only if the fluxing
    // volume of these faces vanishes -- confirm.
    parallel_for(
      m_new_mesh->numberOfCells(), PUGS_LAMBDA(const CellId cell_id) {
        const auto& cell_to_face = cell_to_face_matrix[cell_id];
        for (size_t i_face = 0; i_face < cell_to_face.size(); ++i_face) {
          const FaceId face_id = cell_to_face[i_face];
          CellId donnor_id     = m_donnor_cell[face_id];

          double flux = ((donnor_id == cell_id) ? -1 : 1) * m_cycle_fluxing_volume[face_id];
          step_Vj[cell_id] += flux;
        }
      });

    synchronize(step_Vj);

    parallel_for(
      m_new_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { inv_Vj[cell_id] = 1 / step_Vj[cell_id]; });

    // _remapOne left the quantities multiplied by a volume: divide them by
    // the volume of the cells at the end of the cycle
    for (auto& remapped_q : m_remapped_list) {
      std::visit(
        [&](auto&& new_q) {
          using CellDataType = std::decay_t<decltype(new_q)>;
          static_assert(is_item_value_v<CellDataType> or is_item_array_v<CellDataType>, "invalid data type");

          if constexpr (is_item_value_v<CellDataType>) {
            parallel_for(
              m_new_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { new_q[cell_id] *= inv_Vj[cell_id]; });
          } else if constexpr (is_item_array_v<CellDataType>) {
            parallel_for(
              m_new_mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
                auto array              = new_q[cell_id];
                const double inv_volume = inv_Vj[cell_id];

                for (size_t i = 0; i < array.size(); ++i) {
                  array[i] *= inv_volume;
                }
              });
          }
        },
        remapped_q);
    }
  }
}

// Remaps the given quantities from the old mesh to the new mesh.
//
// The values of each quantity are copied, remapped by _remapAllQuantities(),
// and wrapped into discrete functions of the same kind built on the new mesh.
// The result follows the order of quantity_list.
//
// Throws UnexpectedError if the mesh type of a quantity is not MeshType, or
// if a quantity is neither a P0 nor a P0 vector discrete function.
template <size_t Dimension>
std::vector<std::shared_ptr<const DiscreteFunctionVariant>>
FluxingAdvectionSolver<Dimension>::remap(
  const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& quantity_list)
{
  // Stored values are matched with quantity_list by position when the result
  // is built: values left by a previous call must not remain in the list
  m_remapped_list.clear();
  m_remapped_list.reserve(quantity_list.size());

  for (auto&& variable_v : quantity_list) {
    std::visit(
      [&](auto&& variable) {
        using DiscreteFunctionT = std::decay_t<decltype(variable)>;
        if constexpr (std::is_same_v<MeshType, typename DiscreteFunctionT::MeshType>) {
          this->_storeValues(variable);
        } else {
          throw UnexpectedError("incompatible mesh types");
        }
      },
      variable_v->discreteFunction());
  }

  this->_remapAllQuantities();

  std::vector<std::shared_ptr<const DiscreteFunctionVariant>> new_variables;
  new_variables.reserve(quantity_list.size());

  for (size_t i = 0; i < quantity_list.size(); ++i) {
    std::visit(
      [&](auto&& variable) {
        using DiscreteFunctionT = std::decay_t<decltype(variable)>;
        using DataType          = std::decay_t<typename DiscreteFunctionT::data_type>;

        if constexpr (std::is_same_v<MeshType, typename DiscreteFunctionT::MeshType>) {
          // the i-th stored entry holds the remapped values of the i-th quantity
          if constexpr (is_discrete_function_P0_v<DiscreteFunctionT>) {
            new_variables.push_back(std::make_shared<DiscreteFunctionVariant>(
              DiscreteFunctionT(m_new_mesh, std::get<CellValue<DataType>>(m_remapped_list[i]))));
          } else if constexpr (is_discrete_function_P0_vector_v<DiscreteFunctionT>) {
            new_variables.push_back(std::make_shared<DiscreteFunctionVariant>(
              DiscreteFunctionT(m_new_mesh, std::get<CellArray<DataType>>(m_remapped_list[i]))));
          } else {
            throw UnexpectedError("invalid discrete function type");
          }
        } else {
          throw UnexpectedError("incompatible mesh types");
        }
      },
      quantity_list[i]->discreteFunction());
  }

  return new_variables;
}

// Remaps the given quantities onto i_new_mesh using a FluxingAdvectionSolver
// of the dimension of their common mesh.
//
// Throws NormalError if the quantities are not defined on the same mesh or if
// no common mesh can be retrieved from them, and UnexpectedError if the
// dimension of the mesh is not 1, 2 or 3. The solver constructor also throws
// NormalError when both meshes are not compatible.
std::vector<std::shared_ptr<const DiscreteFunctionVariant>>
advectByFluxing(const std::shared_ptr<const IMesh> i_new_mesh,
                const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& remapped_variables)
{
  if (not hasSameMesh(remapped_variables)) {
    throw NormalError("remapped quantities are not defined on the same mesh");
  }

  const std::shared_ptr<const IMesh> i_old_mesh = getCommonMesh(remapped_variables);

  // getCommonMesh is not defined in this file: guard against an empty pointer
  // (e.g. if the list of quantities is empty) before dereferencing it
  if (not i_old_mesh) {
    throw NormalError("cannot retrieve the mesh of remapped quantities");
  }

  switch (i_old_mesh->dimension()) {
  case 1: {
    return FluxingAdvectionSolver<1>{i_old_mesh, i_new_mesh}.remap(remapped_variables);
  }
  case 2: {
    return FluxingAdvectionSolver<2>{i_old_mesh, i_new_mesh}.remap(remapped_variables);
  }
  case 3: {
    return FluxingAdvectionSolver<3>{i_old_mesh, i_new_mesh}.remap(remapped_variables);
  }
  default: {
    throw UnexpectedError("Invalid mesh dimension");
  }
  }
}