#ifndef CONNECTIVITY_HPP
#define CONNECTIVITY_HPP

#include <PastisAssert.hpp>
#include <TinyVector.hpp>

#include <Kokkos_Core.hpp>

#include <ConnectivityMatrix.hpp>
#include <ConnectivityComputer.hpp>
#include <SubItemValuePerItem.hpp>

#include <vector>
#include <unordered_map>
#include <algorithm>

#include <RefId.hpp>
#include <TypeOfItem.hpp>
#include <RefNodeList.hpp>
#include <RefFaceList.hpp>

#include <tuple>
#include <algorithm>

#include <IConnectivity.hpp>

// Forward declaration of the connectivity class template (defined below);
// needed so that ConnectivityFace<3> can declare Connectivity<3> a friend.
template <size_t Dimension>
class Connectivity;

// Forward declaration of the face key type. Only the explicit
// specializations for Dimension = 1, 2 and 3 below provide definitions.
template <size_t Dimension>
class ConnectivityFace;

// 1D face placeholder: it carries no data. Hash::operator() is only declared
// here; no definition is visible in this header.
// NOTE(review): hashing a 1D face presumably fails at link time unless a
// definition exists in some .cpp file — confirm.
template<>
class ConnectivityFace<1>
{
 public:
  friend struct Hash;
  // Hash functor; its type is required by Connectivity<1>::m_face_number_map.
  struct Hash
  {
    size_t operator()(const ConnectivityFace& f) const;
  };
};

template<>
class ConnectivityFace<2>
{
 public:
  friend struct Hash;

  /// Hash functor allowing 2D faces to key an std::unordered_map: XOR of the
  /// hash of the first node id with the right-shifted hash of the second.
  struct Hash
  {
    size_t operator()(const ConnectivityFace& f) const {
      const size_t first_hash  = std::hash<unsigned int>()(f.m_node0_id);
      const size_t second_hash = std::hash<unsigned int>()(f.m_node1_id) >> 1;
      return first_hash ^ second_hash;
    }
  };

  /// Node ids of the face; the constructor stores them so that
  /// m_node0_id <= m_node1_id.
  unsigned int m_node0_id;
  unsigned int m_node1_id;

  /// Prints both node ids, each followed by a space.
  friend std::ostream& operator<<(std::ostream& os, const ConnectivityFace& f)
  {
    return os << f.m_node0_id << ' ' << f.m_node1_id << ' ';
  }

  /// Faces are equal when both (ordered) node ids match.
  KOKKOS_INLINE_FUNCTION
  bool operator==(const ConnectivityFace& f) const
  {
    return (m_node0_id == f.m_node0_id) && (m_node1_id == f.m_node1_id);
  }

  /// Lexicographic ordering on (m_node0_id, m_node1_id).
  KOKKOS_INLINE_FUNCTION
  bool operator<(const ConnectivityFace& f) const
  {
    if (m_node0_id != f.m_node0_id) {
      return m_node0_id < f.m_node0_id;
    }
    return m_node1_id < f.m_node1_id;
  }

  KOKKOS_INLINE_FUNCTION
  ConnectivityFace& operator=(const ConnectivityFace&) = default;

  KOKKOS_INLINE_FUNCTION
  ConnectivityFace& operator=(ConnectivityFace&&) = default;

  /// Builds a face from exactly two node ids (checked by Assert), storing
  /// them in increasing order so both orientations give the same face.
  KOKKOS_INLINE_FUNCTION
  ConnectivityFace(const std::vector<unsigned int>& given_node_id_list)
  {
    Assert(given_node_id_list.size()==2);
#warning rework this dirty constructor
    const unsigned int first_id  = given_node_id_list[0];
    const unsigned int second_id = given_node_id_list[1];
    // Same selection rule as std::minmax: swap only when second < first.
    if (second_id < first_id) {
      m_node0_id = second_id;
      m_node1_id = first_id;
    } else {
      m_node0_id = first_id;
      m_node1_id = second_id;
    }
  }

  KOKKOS_INLINE_FUNCTION
  ConnectivityFace(const ConnectivityFace&) = default;

  KOKKOS_INLINE_FUNCTION
  ConnectivityFace(ConnectivityFace&&) = default;

  KOKKOS_INLINE_FUNCTION
  ~ConnectivityFace() = default;
};

template <>
class ConnectivityFace<3>
{
 private:
  friend class Connectivity<3>;
  friend struct Hash;

  /// Hash functor for unordered containers keyed on 3D faces: XOR of the
  /// node id hashes, each right-shifted by the node's position in the
  /// canonical list.
  struct Hash
  {
    size_t operator()(const ConnectivityFace& f) const {
      size_t hash = 0;
      for (size_t i=0; i<f.m_node_id_list.size(); ++i) {
        hash ^= std::hash<unsigned int>()(f.m_node_id_list[i]) >> i;
      }
      return hash;
    }
  };

  // m_reversed MUST stay declared before m_node_id_list: the constructor
  // initializes m_reversed first and then calls _sort(), which may set it to
  // true while computing m_node_id_list.
  bool m_reversed;
  std::vector<unsigned int> m_node_id_list;

  /// Writes the canonical node ids of the face to os, each followed by a
  /// space.
  friend std::ostream& operator<<(std::ostream& os, const ConnectivityFace& f)
  {
    for (auto id : f.m_node_id_list) {
      os << id << ' ';
    }
    return os;
  }

  /// True when _sort() had to reverse the traversal direction of the nodes.
  KOKKOS_INLINE_FUNCTION
  const bool& reversed() const
  {
    return m_reversed;
  }

  /// Canonically ordered node ids of the face.
  KOKKOS_INLINE_FUNCTION
  const std::vector<unsigned int>& nodeIdList() const
  {
    return m_node_id_list;
  }

  /// Returns node_list rotated so that its smallest id comes first.
  ///
  /// If the neighbour following the minimum is larger than the neighbour
  /// preceding it, the traversal direction is also reversed and m_reversed
  /// is set to true. In both cases the second entry of the result is the
  /// smaller of the two neighbours of the minimum. An empty list yields an
  /// empty result.
  KOKKOS_INLINE_FUNCTION
  std::vector<unsigned int> _sort(const std::vector<unsigned int>& node_list)
  {
    const size_t nb_nodes = node_list.size();
    if (nb_nodes == 0) {
      // Guard: the modular index arithmetic below would divide by zero.
      return {};
    }

    const auto min_id = std::min_element(node_list.begin(), node_list.end());
    const size_t shift = static_cast<size_t>(std::distance(node_list.begin(), min_id));

    std::vector<unsigned int> rotated_node_list(nb_nodes);
    if (node_list[(shift+1)%nb_nodes] > node_list[(shift+nb_nodes-1)%nb_nodes]) {
      m_reversed = true;
      for (size_t i=0; i<nb_nodes; ++i) {
        rotated_node_list[i] = node_list[(shift+nb_nodes-i)%nb_nodes];
      }
    } else {
      for (size_t i=0; i<nb_nodes; ++i) {
        rotated_node_list[i] = node_list[(shift+i)%nb_nodes];
      }
    }

    return rotated_node_list;
  }

  /// Builds the canonical face from given_node_id_list. Private: only
  /// Connectivity<3> (a friend) may create faces.
  KOKKOS_INLINE_FUNCTION
  ConnectivityFace(const std::vector<unsigned int>& given_node_id_list)
      : m_reversed(false),
        m_node_id_list(_sort(given_node_id_list))
  {}

 public:
  /// Faces are equal when their canonical node lists are identical.
  bool operator==(const ConnectivityFace& f) const
  {
    return m_node_id_list == f.m_node_id_list;
  }

  /// Lexicographic ordering of the canonical node lists; a strict prefix
  /// compares less.
  KOKKOS_INLINE_FUNCTION
  bool operator<(const ConnectivityFace& f) const
  {
    const size_t min_nb_nodes = std::min(f.m_node_id_list.size(), m_node_id_list.size());
    for (size_t i=0; i<min_nb_nodes; ++i) {
      if (m_node_id_list[i] <  f.m_node_id_list[i]) return true;
      if (m_node_id_list[i] != f.m_node_id_list[i]) return false;
    }
    return m_node_id_list.size() < f.m_node_id_list.size();
  }

  KOKKOS_INLINE_FUNCTION
  ConnectivityFace& operator=(const ConnectivityFace&) = default;

  KOKKOS_INLINE_FUNCTION
  ConnectivityFace& operator=(ConnectivityFace&&) = default;

  KOKKOS_INLINE_FUNCTION
  ConnectivityFace(const ConnectivityFace&) = default;

  KOKKOS_INLINE_FUNCTION
  ConnectivityFace(ConnectivityFace&&) = default;

  KOKKOS_INLINE_FUNCTION
  ConnectivityFace() = delete;

  KOKKOS_INLINE_FUNCTION
  ~ConnectivityFace() = default;
};

/// Mesh connectivity of a given Dimension.
///
/// Item-to-item relations (cell->node, node->cell, face->node, ...) are
/// stored as ConnectivityMatrix objects in m_item_to_item_matrix, indexed by
/// itemId(TypeOfItem). The class also holds the reference node/face lists
/// attached to the mesh and a map from canonical faces to face numbers.
template <size_t Dimension>
class Connectivity final
    : public IConnectivity
{
 private:
  constexpr static auto& itemId = ItemId<Dimension>::itemId;

 public:
  static constexpr size_t dimension = Dimension;

 public:
  /// Cell -> node connectivity matrix (returned by value).
  KOKKOS_INLINE_FUNCTION
  ConnectivityMatrix cellToNodeMatrix() const
  {
    return m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::node)];
  }

  /// Cell -> face connectivity matrix (returned by value).
  KOKKOS_INLINE_FUNCTION
  ConnectivityMatrix cellToFaceMatrix() const
  {
    return m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::face)];
  }

  /// Face -> cell connectivity matrix (returned by value).
  KOKKOS_INLINE_FUNCTION
  ConnectivityMatrix faceToCellMatrix() const
  {
    return m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::cell)];
  }

  /// Face -> node connectivity matrix (returned by value).
  KOKKOS_INLINE_FUNCTION
  ConnectivityMatrix faceToNodeMatrix() const
  {
    return m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::node)];
  }

  /// Node -> cell connectivity matrix (returned by value).
  KOKKOS_INLINE_FUNCTION
  ConnectivityMatrix nodeToCellMatrix() const
  {
    return m_item_to_item_matrix[itemId(TypeOfItem::node)][itemId(TypeOfItem::cell)];
  }

  // Local-numbering tables. The cell/node ones are filled in the constructor
  // by ConnectivityComputer::computeLocalChildItemNumberInItem; the face ones
  // are presumably filled by _computeFaceCellConnectivities (defined
  // elsewhere) — NOTE(review): confirm.
  NodeValuePerCell<unsigned short> m_cell_to_node_local_cell;

  FaceValuePerCell<bool> m_cell_face_is_reversed;

  CellValuePerFace<unsigned short> m_face_to_cell_local_face;

  CellValuePerNode<unsigned short> m_node_to_cell_local_node;

  /// Compile-time accessor. The primary template is deleted: only the
  /// explicit specializations defined after the class can be used.
  template <TypeOfItem SubItemType,
            TypeOfItem ItemType>
  const ConnectivityMatrix& itemToItemMatrix() const = delete;

  /// Run-time accessor dispatching to the compile-time specializations.
  KOKKOS_INLINE_FUNCTION
  const ConnectivityMatrix& itemToItemMatrix(const TypeOfItem& item_type_0,
                                             const TypeOfItem& item_type_1) const final;

 private:
  ConnectivityMatrix m_item_to_item_matrix[Dimension+1][Dimension+1];

  ConnectivityComputer m_connectivity_computer;

  std::vector<RefFaceList> m_ref_face_list;
  std::vector<RefNodeList> m_ref_node_list;

  // Per cell: 1 / (number of nodes of the cell); computed in the constructor.
  Kokkos::View<double*> m_inv_cell_nb_nodes;

  Kokkos::View<const unsigned short*> m_node_nb_faces;
  Kokkos::View<const unsigned int**> m_node_faces;

  using Face = ConnectivityFace<Dimension>;

  // Canonical face -> face number, queried by getFaceNumber().
  std::unordered_map<Face, unsigned int, typename Face::Hash> m_face_number_map;

  void _computeFaceCellConnectivities();

 public:
  void addRefFaceList(const RefFaceList& ref_face_list)
  {
    m_ref_face_list.push_back(ref_face_list);
  }

  size_t numberOfRefFaceList() const
  {
    return m_ref_face_list.size();
  }

  /// Unchecked access to the i-th reference face list.
  const RefFaceList& refFaceList(const size_t& i) const
  {
    return m_ref_face_list[i];
  }

  void addRefNodeList(const RefNodeList& ref_node_list)
  {
    m_ref_node_list.push_back(ref_node_list);
  }

  size_t numberOfRefNodeList() const
  {
    return m_ref_node_list.size();
  }

  /// Unchecked access to the i-th reference node list.
  const RefNodeList& refNodeList(const size_t& i) const
  {
    return m_ref_node_list[i];
  }

  /// Number of nodes, i.e. number of rows of the node -> cell matrix.
  KOKKOS_INLINE_FUNCTION
  size_t numberOfNodes() const
  {
    const auto& node_to_cell_matrix
        = m_item_to_item_matrix[itemId(TypeOfItem::node)][itemId(TypeOfItem::cell)];
    return node_to_cell_matrix.numRows();
  }

  /// Number of faces, i.e. number of rows of the face -> node matrix.
  KOKKOS_INLINE_FUNCTION
  size_t numberOfFaces() const
  {
    const auto& face_to_node_matrix
        = m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::node)];
    return face_to_node_matrix.numRows();
  }

  /// Number of cells, i.e. number of rows of the cell -> node matrix.
  KOKKOS_INLINE_FUNCTION
  size_t numberOfCells() const
  {
    const auto& cell_to_node_matrix
        = m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::node)];
    return cell_to_node_matrix.numRows();
  }

  /// Per-cell inverse of the number of nodes (read-only view).
  const Kokkos::View<const double*> invCellNbNodes() const
  {
    return m_inv_cell_nb_nodes;
  }

  /// Returns the number of the face made of face_nodes. The nodes are first
  /// canonicalized by the Face constructor, so every node ordering mapping
  /// to the same Face key is accepted.
  /// @throws std::exception (after printing a message to std::cerr) when no
  ///         such face is registered in m_face_number_map.
  unsigned int getFaceNumber(const std::vector<unsigned int>& face_nodes) const
  {
    const Face face(face_nodes);
    const auto i_face = m_face_number_map.find(face);
    if (i_face == m_face_number_map.end()) {
      std::cerr << "Face " << face << " not found!\n";
      throw std::exception();
    }
    return i_face->second;
  }

  Connectivity(const Connectivity&) = delete;

  /// Builds the connectivity from the list of node ids of each cell.
  ///
  /// Computes, in order: the cell -> node matrix, the per-cell inverse node
  /// count, the node -> cell matrix (inverse of cell -> node), the local
  /// numbering tables, and, for Dimension > 1, the face connectivities.
  Connectivity(const std::vector<std::vector<unsigned int>>& cell_by_node_vector)
  {
    auto& cell_to_node_matrix
        = m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::node)];
    cell_to_node_matrix = cell_by_node_vector;

    Assert(this->numberOfCells()>0);

    {
      Kokkos::View<double*> inv_cell_nb_nodes("inv_cell_nb_nodes", this->numberOfCells());
      Kokkos::parallel_for(this->numberOfCells(), KOKKOS_LAMBDA(const int& j){
          const auto& cell_nodes = cell_to_node_matrix.rowConst(j);
          inv_cell_nb_nodes[j] = 1./cell_nodes.length;
        });
      m_inv_cell_nb_nodes = inv_cell_nb_nodes;
    }

    auto& node_to_cell_matrix
        = m_item_to_item_matrix[itemId(TypeOfItem::node)][itemId(TypeOfItem::cell)];

    m_connectivity_computer.computeInverseConnectivityMatrix(cell_to_node_matrix,
                                                             node_to_cell_matrix);

    m_node_to_cell_local_node = CellValuePerNode<unsigned short>(*this);

    m_connectivity_computer.computeLocalChildItemNumberInItem(cell_to_node_matrix,
                                                              node_to_cell_matrix,
                                                              m_node_to_cell_local_node);

    m_cell_to_node_local_cell = NodeValuePerCell<unsigned short>(*this);

    m_connectivity_computer.computeLocalChildItemNumberInItem(node_to_cell_matrix,
                                                              cell_to_node_matrix,
                                                              m_cell_to_node_local_cell);
    if constexpr (Dimension>1) {
      this->_computeFaceCellConnectivities();
    }
  }

  ~Connectivity() = default;
};


using Connectivity3D = Connectivity<3>;

// Compile-time accessors of Connectivity<3>: each explicit specialization
// returns a reference to the stored matrix for its (from, to) item pair.

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<3>::itemToItemMatrix<TypeOfItem::cell,
                                  TypeOfItem::face>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::face)];
}

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<3>::itemToItemMatrix<TypeOfItem::cell,
                                  TypeOfItem::node>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::node)];
}

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<3>::itemToItemMatrix<TypeOfItem::face,
                                  TypeOfItem::cell>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::cell)];
}

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<3>::itemToItemMatrix<TypeOfItem::face,
                                  TypeOfItem::node>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::node)];
}

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<3>::itemToItemMatrix<TypeOfItem::node,
                                  TypeOfItem::cell>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::node)][itemId(TypeOfItem::cell)];
}


using Connectivity2D = Connectivity<2>;

// Compile-time accessors of Connectivity<2>: each explicit specialization
// returns a reference to the stored matrix for its (from, to) item pair.

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<2>::itemToItemMatrix<TypeOfItem::cell,
                                  TypeOfItem::face>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::face)];
}

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<2>::itemToItemMatrix<TypeOfItem::cell,
                                  TypeOfItem::node>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::node)];
}

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<2>::itemToItemMatrix<TypeOfItem::face,
                                  TypeOfItem::cell>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::cell)];
}

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<2>::itemToItemMatrix<TypeOfItem::face,
                                  TypeOfItem::node>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::node)];
}

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<2>::itemToItemMatrix<TypeOfItem::node,
                                  TypeOfItem::cell>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::node)][itemId(TypeOfItem::cell)];
}

using Connectivity1D = Connectivity<1>;

// Compile-time accessors of Connectivity<1>: each explicit specialization
// returns a reference to the stored matrix for its (from, to) item pair.

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<1>::itemToItemMatrix<TypeOfItem::cell,
                                  TypeOfItem::node>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::node)];
}

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<1>::itemToItemMatrix<TypeOfItem::cell,
                                  TypeOfItem::face>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::cell)][itemId(TypeOfItem::face)];
}

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<1>::itemToItemMatrix<TypeOfItem::face,
                                  TypeOfItem::cell>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::cell)];
}

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<1>::itemToItemMatrix<TypeOfItem::face,
                                  TypeOfItem::node>() const
{
#warning in 1d, faces and node are the same
  return m_item_to_item_matrix[itemId(TypeOfItem::face)][itemId(TypeOfItem::node)];
}

template <>
template <>
inline const ConnectivityMatrix&
Connectivity<1>::itemToItemMatrix<TypeOfItem::node,
                                  TypeOfItem::cell>() const
{
  return m_item_to_item_matrix[itemId(TypeOfItem::node)][itemId(TypeOfItem::cell)];
}

/// Run-time dispatch from an (item_type_0, item_type_1) pair to the matching
/// compile-time accessor itemToItemMatrix<item_type_0, item_type_1>().
/// Unsupported pairs print a "NIY" (not implemented yet) diagnostic with the
/// offending enum value to std::cerr and terminate with std::exit(1).
template <size_t Dimension>
const ConnectivityMatrix&
Connectivity<Dimension>::
itemToItemMatrix(const TypeOfItem& item_type_0,
                 const TypeOfItem& item_type_1) const
{
  if (item_type_0 == TypeOfItem::cell) {
    if (item_type_1 == TypeOfItem::node) {
      return itemToItemMatrix<TypeOfItem::cell, TypeOfItem::node>();
    }
    if (item_type_1 == TypeOfItem::face) {
      return itemToItemMatrix<TypeOfItem::cell, TypeOfItem::face>();
    }
    std::cerr << __FILE__ << ":" << __LINE__ << ": NIY " << static_cast<int>(item_type_1) << "\n";
    std::exit(1);
  }

  if (item_type_0 == TypeOfItem::face) {
    if (item_type_1 == TypeOfItem::cell) {
      return itemToItemMatrix<TypeOfItem::face, TypeOfItem::cell>();
    }
    if (item_type_1 == TypeOfItem::node) {
      return itemToItemMatrix<TypeOfItem::face, TypeOfItem::node>();
    }
    std::cerr << __FILE__ << ":" << __LINE__ << ": NIY " << static_cast<int>(item_type_1) << "\n";
    std::exit(1);
  }

  if (item_type_0 == TypeOfItem::node) {
    if (item_type_1 == TypeOfItem::cell) {
      return itemToItemMatrix<TypeOfItem::node, TypeOfItem::cell>();
    }
    std::cerr << __FILE__ << ":" << __LINE__ << ": NIY " << static_cast<int>(item_type_1) << "\n";
    std::exit(1);
  }

  // Source item type not supported at all.
  std::cerr << __FILE__ << ":" << __LINE__ << ": NIY " << static_cast<int>(item_type_0) << "\n";
  std::exit(1);
}


#endif // CONNECTIVITY_HPP