#ifndef CONNECTIVITY_TO_DIAMOND_DUAL_CONNECTIVITY_DATA_MAPPER_HPP
#define CONNECTIVITY_TO_DIAMOND_DUAL_CONNECTIVITY_DATA_MAPPER_HPP

#include <mesh/Connectivity.hpp>
#include <mesh/ItemValue.hpp>
#include <utils/Array.hpp>
#include <utils/PugsAssert.hpp>

// Common, dimension-independent base of the primal <-> diamond dual data
// mappers (ConnectivityToDiamondDualConnectivityDataMapper<Dimension> derives
// from it). It has a virtual destructor so derived mappers can be destroyed
// through a base pointer, and it can be neither copied nor moved.
class IConnectivityToDiamondDualConnectivityDataMapper
{
 public:
  IConnectivityToDiamondDualConnectivityDataMapper()          = default;
  virtual ~IConnectivityToDiamondDualConnectivityDataMapper() = default;

  // Non-copyable and non-movable (the assignment operators were already
  // implicitly unavailable; they are spelled out here for clarity).
  IConnectivityToDiamondDualConnectivityDataMapper(const IConnectivityToDiamondDualConnectivityDataMapper&) = delete;
  IConnectivityToDiamondDualConnectivityDataMapper(IConnectivityToDiamondDualConnectivityDataMapper&&)      = delete;
  IConnectivityToDiamondDualConnectivityDataMapper& operator=(const IConnectivityToDiamondDualConnectivityDataMapper&) =
    delete;
  IConnectivityToDiamondDualConnectivityDataMapper& operator=(IConnectivityToDiamondDualConnectivityDataMapper&&) =
    delete;
};

// Copies item values between a primal connectivity and its diamond dual
// connectivity.
//
// Correspondences built by the constructor:
//  - primal nodes -> dual nodes: for Dimension == 1 only the primal nodes that
//    belong to exactly one cell are mapped, numbered consecutively from 0;
//    for other dimensions primal node i maps to dual node i;
//  - primal cells -> dual nodes: numbered consecutively, starting right after
//    the last dual node used by the node map;
//  - primal faces -> dual cells: primal face i maps to dual cell i.
//
// NOTE(review): these numberings must agree with the ones used when the dual
// connectivity was built; this class does not verify that -- confirm against
// the dual connectivity builder.
template <size_t Dimension>
class ConnectivityToDiamondDualConnectivityDataMapper : public IConnectivityToDiamondDualConnectivityDataMapper
{
 private:
  // Only used in Assert checks that the item values handed to the transfer
  // methods are defined on the expected connectivities.
  const IConnectivity* m_primal_connectivity;
  const IConnectivity* m_dual_connectivity;

  // (primal node id, dual node id) pairs
  using NodeIdToNodeIdMap = Array<std::pair<NodeId, NodeId>>;
  NodeIdToNodeIdMap m_primal_node_to_dual_node_map;

  // (primal cell id, dual node id) pairs
  using CellIdToNodeIdMap = Array<std::pair<CellId, NodeId>>;
  CellIdToNodeIdMap m_primal_cell_to_dual_node_map;

  // (primal face id, dual cell id) pairs
  using FaceIdToCellIdMap = Array<std::pair<FaceId, CellId>>;
  FaceIdToCellIdMap m_primal_face_to_dual_cell_map;

 public:
  // Fills dual node values from primal node values and primal cell values.
  //
  // Only dual nodes that appear in the node map or in the cell map are
  // written; any other dual node value is left unchanged.
  //
  // primal_node_value: values on the primal nodes (read)
  // primal_cell_value: values on the primal cells (read)
  // dual_node_value:   values on the dual nodes (written, must not hold const data)
  template <typename OriginDataType1, typename OriginDataType2, typename DestinationDataType>
  void
  toDualNode(const NodeValue<OriginDataType1>& primal_node_value,
             const CellValue<OriginDataType2>& primal_cell_value,
             const NodeValue<DestinationDataType>& dual_node_value) const
  {
    static_assert(not std::is_const_v<DestinationDataType>, "destination data type must not be constant");
    static_assert(std::is_same_v<std::remove_const_t<OriginDataType1>, DestinationDataType>, "incompatible types");
    static_assert(std::is_same_v<std::remove_const_t<OriginDataType2>, DestinationDataType>, "incompatible types");

    Assert(m_primal_connectivity == primal_cell_value.connectivity_ptr().get());
    Assert(m_primal_connectivity == primal_node_value.connectivity_ptr().get());
    Assert(m_dual_connectivity == dual_node_value.connectivity_ptr().get());

    parallel_for(
      m_primal_node_to_dual_node_map.size(), PUGS_LAMBDA(size_t i) {
        const auto [primal_node_id, dual_node_id] = m_primal_node_to_dual_node_map[i];

        dual_node_value[dual_node_id] = primal_node_value[primal_node_id];
      });

    parallel_for(
      m_primal_cell_to_dual_node_map.size(), PUGS_LAMBDA(size_t i) {
        const auto [primal_cell_id, dual_node_id] = m_primal_cell_to_dual_node_map[i];

        dual_node_value[dual_node_id] = primal_cell_value[primal_cell_id];
      });
  }

  // Inverse of toDualNode: fills primal node and primal cell values from dual
  // node values.
  //
  // Only primal nodes present in the node map are written (for Dimension == 1
  // this excludes nodes shared by two cells, whose values are left unchanged);
  // every primal cell present in the cell map is written.
  //
  // dual_node_value:   values on the dual nodes (read)
  // primal_node_value: values on the primal nodes (written, must not hold const data)
  // primal_cell_value: values on the primal cells (written, must not hold const data)
  template <typename OriginDataType, typename DestinationDataType1, typename DestinationDataType2>
  void
  fromDualNode(const NodeValue<OriginDataType>& dual_node_value,
               const NodeValue<DestinationDataType1>& primal_node_value,
               const CellValue<DestinationDataType2>& primal_cell_value) const
  {
    static_assert(not std::is_const_v<DestinationDataType1>, "destination data type must not be constant");
    static_assert(not std::is_const_v<DestinationDataType2>, "destination data type must not be constant");
    static_assert(std::is_same_v<std::remove_const_t<OriginDataType>, DestinationDataType1>, "incompatible types");
    static_assert(std::is_same_v<std::remove_const_t<OriginDataType>, DestinationDataType2>, "incompatible types");

    Assert(m_primal_connectivity == primal_cell_value.connectivity_ptr().get());
    Assert(m_primal_connectivity == primal_node_value.connectivity_ptr().get());
    Assert(m_dual_connectivity == dual_node_value.connectivity_ptr().get());

    parallel_for(
      m_primal_node_to_dual_node_map.size(), PUGS_LAMBDA(size_t i) {
        const auto [primal_node_id, dual_node_id] = m_primal_node_to_dual_node_map[i];

        primal_node_value[primal_node_id] = dual_node_value[dual_node_id];
      });

    parallel_for(
      m_primal_cell_to_dual_node_map.size(), PUGS_LAMBDA(size_t i) {
        const auto [primal_cell_id, dual_node_id] = m_primal_cell_to_dual_node_map[i];

        primal_cell_value[primal_cell_id] = dual_node_value[dual_node_id];
      });
  }

  // Fills dual cell values from primal face values using the face -> dual
  // cell map.
  //
  // primal_face_value: values on the primal faces (read)
  // dual_cell_value:   values on the dual cells (written, must not hold const data)
  template <typename OriginDataType, typename DestinationDataType>
  void
  toDualCell(const FaceValue<OriginDataType>& primal_face_value,
             const CellValue<DestinationDataType>& dual_cell_value) const
  {
    static_assert(not std::is_const_v<DestinationDataType>, "destination data type must not be constant");
    static_assert(std::is_same_v<std::remove_const_t<OriginDataType>, DestinationDataType>, "incompatible types");

    Assert(m_primal_connectivity == primal_face_value.connectivity_ptr().get());
    Assert(m_dual_connectivity == dual_cell_value.connectivity_ptr().get());

    // One pass over the face -> dual cell map: its own size is the only valid
    // bound for indexing it.
    parallel_for(
      m_primal_face_to_dual_cell_map.size(), PUGS_LAMBDA(size_t i) {
        const auto [primal_face_id, dual_cell_id] = m_primal_face_to_dual_cell_map[i];

        dual_cell_value[dual_cell_id] = primal_face_value[primal_face_id];
      });
  }

  // Inverse of toDualCell: fills primal face values from dual cell values.
  //
  // dual_cell_value:   values on the dual cells (read, may hold const data)
  // primal_face_value: values on the primal faces (written, must not hold const data)
  template <typename OriginDataType, typename DestinationDataType>
  void
  fromDualCell(const CellValue<OriginDataType>& dual_cell_value,
               const FaceValue<DestinationDataType>& primal_face_value) const
  {
    // The const check applies to the written (primal face) values, not to the
    // source dual cell values.
    static_assert(not std::is_const_v<DestinationDataType>, "destination data type must not be constant");
    static_assert(std::is_same_v<std::remove_const_t<OriginDataType>, DestinationDataType>, "incompatible types");

    Assert(m_primal_connectivity == primal_face_value.connectivity_ptr().get());
    Assert(m_dual_connectivity == dual_cell_value.connectivity_ptr().get());

    // One pass over the face -> dual cell map: its own size is the only valid
    // bound for indexing it.
    parallel_for(
      m_primal_face_to_dual_cell_map.size(), PUGS_LAMBDA(size_t i) {
        const auto [primal_face_id, dual_cell_id] = m_primal_face_to_dual_cell_map[i];

        primal_face_value[primal_face_id] = dual_cell_value[dual_cell_id];
      });
  }

  // Builds the node, cell and face correspondence tables described in the
  // class comment. Both connectivities must outlive this mapper: only their
  // addresses are stored.
  ConnectivityToDiamondDualConnectivityDataMapper(const Connectivity<Dimension>& primal_connectivity,
                                                  const Connectivity<Dimension>& dual_connectivity)
    : m_primal_connectivity{&primal_connectivity}, m_dual_connectivity{&dual_connectivity}
  {
    if constexpr (Dimension == 1) {
      const auto& node_to_cell_matrix = primal_connectivity.nodeToCellMatrix();

      // Shared counter: dual node numbering continues from the node map into
      // the cell map.
      NodeId dual_node_id            = 0;
      m_primal_node_to_dual_node_map = [&]() {
        std::vector<std::pair<NodeId, NodeId>> primal_node_to_dual_node_vector;
        for (NodeId primal_node_id = 0; primal_node_id < primal_connectivity.numberOfNodes(); ++primal_node_id) {
          // Only primal nodes that belong to a single cell get a dual node.
          if (node_to_cell_matrix[primal_node_id].size() == 1) {
            primal_node_to_dual_node_vector.emplace_back(primal_node_id, dual_node_id++);
          }
        }
        return convert_to_array(primal_node_to_dual_node_vector);
      }();

      m_primal_cell_to_dual_node_map = [&]() {
        CellIdToNodeIdMap primal_cell_to_dual_node_map{primal_connectivity.numberOfCells()};
        for (CellId primal_cell_id = 0; primal_cell_id < primal_cell_to_dual_node_map.size(); ++primal_cell_id) {
          primal_cell_to_dual_node_map[primal_cell_id] = std::make_pair(primal_cell_id, dual_node_id++);
        }
        return primal_cell_to_dual_node_map;
      }();

    } else {
      // Every primal node keeps its id as a dual node.
      m_primal_node_to_dual_node_map = [&]() {
        NodeIdToNodeIdMap primal_node_to_dual_node_map{primal_connectivity.numberOfNodes()};
        for (NodeId primal_node_id = 0; primal_node_id < primal_node_to_dual_node_map.size(); ++primal_node_id) {
          const NodeId dual_node_id = primal_node_id;

          primal_node_to_dual_node_map[primal_node_id] = std::make_pair(primal_node_id, dual_node_id);
        }
        return primal_node_to_dual_node_map;
      }();

      // Primal cells are numbered as dual nodes right after the primal nodes.
      m_primal_cell_to_dual_node_map = [&]() {
        CellIdToNodeIdMap primal_cell_to_dual_node_map{primal_connectivity.numberOfCells()};
        NodeId dual_node_id = m_primal_node_to_dual_node_map.size();
        for (CellId primal_cell_id = 0; primal_cell_id < primal_cell_to_dual_node_map.size(); ++primal_cell_id) {
          primal_cell_to_dual_node_map[primal_cell_id] = std::make_pair(primal_cell_id, dual_node_id++);
        }
        return primal_cell_to_dual_node_map;
      }();
    }

    // Primal face i corresponds to dual cell i, in every dimension.
    m_primal_face_to_dual_cell_map = [&]() {
      FaceIdToCellIdMap primal_face_to_dual_cell_map{primal_connectivity.numberOfFaces()};
      for (size_t id = 0; id < primal_face_to_dual_cell_map.size(); ++id) {
        const CellId dual_cell_id   = id;
        const FaceId primal_face_id = id;

        primal_face_to_dual_cell_map[id] = std::make_pair(primal_face_id, dual_cell_id);
      }
      return primal_face_to_dual_cell_map;
    }();
  }
};

#endif   // CONNECTIVITY_TO_DIAMOND_DUAL_CONNECTIVITY_DATA_MAPPER_HPP