#ifndef CONNECTIVITY_3D_HPP
#define CONNECTIVITY_3D_HPP

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <map>
#include <tuple>
#include <unordered_map>
#include <vector>

#include <Kokkos_Core.hpp>
#include <PastisAssert.hpp>
#include <TinyVector.hpp>

#include <ConnectivityUtils.hpp>

#include <RefId.hpp>
#include <TypeOfItem.hpp>
#include <RefNodeList.hpp>
#include <RefFaceList.hpp>

class Connectivity3D
{
public:
  static constexpr size_t dimension = 3;

  ConnectivityMatrix m_cell_to_node_matrix;

  ConnectivityMatrix m_cell_to_face_matrix;
  ConnectivityMatrixShort m_cell_to_face_is_reversed_matrix;

  ConnectivityMatrix m_face_to_cell_matrix;
  ConnectivityMatrixShort m_face_to_cell_local_face_matrix;
  ConnectivityMatrix m_face_to_node_matrix;

  ConnectivityMatrix m_node_to_cell_matrix;
  ConnectivityMatrixShort m_node_to_cell_local_node_matrix;

  // Stores numbering of nodes of each cell.
  // gives an id to each node of each cell. (j,r) -> id
  //
  // This is different from m_cell_to_node_matrix which return the global id of
  // a local node in a cell
  ConnectivityMatrix m_node_id_per_cell_matrix;

  inline ConnectivityMatrix subItemIdPerItemMatrix() const
  {
    return m_node_id_per_cell_matrix;
  }

private:
  std::vector<RefFaceList> m_ref_face_list;
  std::vector<RefNodeList> m_ref_node_list;

  Kokkos::View<double*> m_inv_cell_nb_nodes;

  Kokkos::View<const unsigned short*> m_node_nb_faces;
  Kokkos::View<const unsigned int**> m_node_faces;

  class Face
  {
   public:
    friend struct Hash;
    struct Hash {
      size_t operator()(const Face& f) const
      {
        size_t hash = 0;
        for (size_t i=0; i<f.m_node_id_list.size(); ++i) {
          hash ^= std::hash<unsigned int>()(f.m_node_id_list[i]) >> i;
        }
        return hash;
      }
    };
   private:
    bool m_reversed;
    std::vector<unsigned int> m_node_id_list;

   public:
    friend std::ostream& operator<<(std::ostream& os, const Face& f)
    {
      for (auto id : f.m_node_id_list) {
        std::cout << id << ' ';
      }
      return os;
    }

    KOKKOS_INLINE_FUNCTION
    const bool& reversed() const
    {
      return m_reversed;
    }

    KOKKOS_INLINE_FUNCTION
    const std::vector<unsigned int>& nodeIdList() const
    {
      return m_node_id_list;
    }

    KOKKOS_INLINE_FUNCTION
    std::vector<unsigned int> _sort(const std::vector<unsigned int>& node_list)
    {
      const auto min_id = std::min_element(node_list.begin(), node_list.end());
      const int shift = std::distance(node_list.begin(), min_id);

      std::vector<unsigned int> rotated_node_list(node_list.size());
      if (node_list[(shift+1)%node_list.size()] > node_list[(shift+node_list.size()-1)%node_list.size()]) {
        for (size_t i=0; i<node_list.size(); ++i) {
          rotated_node_list[i] = node_list[(shift+node_list.size()-i)%node_list.size()];
          m_reversed = true;
        }
      } else {
        for (size_t i=0; i<node_list.size(); ++i) {
          rotated_node_list[i] = node_list[(shift+i)%node_list.size()];
        }
      }

      return rotated_node_list;
    }

    bool operator==(const Face& f) const
    {
      if (m_node_id_list.size() == f.nodeIdList().size()) {
        for (size_t j=0; j<m_node_id_list.size(); ++j) {
          if (m_node_id_list[j] != f.nodeIdList()[j]) {
            return false;
          }
        }
        return true;
      }
      return false;
    }

    KOKKOS_INLINE_FUNCTION
    bool operator<(const Face& f) const
    {
      const size_t min_nb_nodes = std::min(f.m_node_id_list.size(), m_node_id_list.size());
      for (size_t i=0; i<min_nb_nodes; ++i) {
        if (m_node_id_list[i] <  f.m_node_id_list[i]) return true;
        if (m_node_id_list[i] != f.m_node_id_list[i]) return false;
      }
      return m_node_id_list.size() < f.m_node_id_list.size();
    }

    KOKKOS_INLINE_FUNCTION
    Face& operator=(const Face&) = default;

    KOKKOS_INLINE_FUNCTION
    Face& operator=(Face&&) = default;

    KOKKOS_INLINE_FUNCTION
    Face(const Face&) = default;

    KOKKOS_INLINE_FUNCTION
    Face(Face&&) = default;

    KOKKOS_INLINE_FUNCTION
    Face(const std::vector<unsigned int>& given_node_id_list)
        : m_reversed(false),
          m_node_id_list(_sort(given_node_id_list))
    {
      ;
    }

    KOKKOS_INLINE_FUNCTION
    Face() = delete;

    KOKKOS_INLINE_FUNCTION
    ~Face() = default;
  };

  std::unordered_map<Face, unsigned int, Face::Hash> m_face_number_map;

  void _computeFaceCellConnectivities()
  {
    Kokkos::View<unsigned short*> cell_nb_faces("cell_nb_faces", this->numberOfCells());

    typedef std::tuple<unsigned int, unsigned short, bool> CellFaceInfo;
    std::map<Face, std::vector<CellFaceInfo>> face_cells_map;
    for (unsigned int j=0; j<this->numberOfCells(); ++j) {
      const auto& cell_nodes = m_cell_to_node_matrix.rowConst(j);

      switch (cell_nodes.length) {
        case 4: { // tetrahedron
          cell_nb_faces[j] = 4;
          // face 0
          Face f0({cell_nodes(1),
                   cell_nodes(2),
                   cell_nodes(3)});
          face_cells_map[f0].push_back(std::make_tuple(j, 0, f0.reversed()));

          // face 1
          Face f1({cell_nodes(0),
                   cell_nodes(3),
                   cell_nodes(2)});
          face_cells_map[f1].push_back(std::make_tuple(j, 1, f1.reversed()));

          // face 2
          Face f2({cell_nodes(0),
                   cell_nodes(1),
                   cell_nodes(3)});
          face_cells_map[f2].push_back(std::make_tuple(j, 2, f2.reversed()));

          // face 3
          Face f3({cell_nodes(0),
                   cell_nodes(2),
                   cell_nodes(1)});
          face_cells_map[f3].push_back(std::make_tuple(j, 3, f3.reversed()));
          break;
        }
        case 8: { // hexahedron
          // face 0
          Face f0({cell_nodes(3),
                   cell_nodes(2),
                   cell_nodes(1),
                   cell_nodes(0)});
          face_cells_map[f0].push_back(std::make_tuple(j, 0, f0.reversed()));

          // face 1
          Face f1({cell_nodes(4),
                   cell_nodes(5),
                   cell_nodes(6),
                   cell_nodes(7)});
          face_cells_map[f1].push_back(std::make_tuple(j, 1, f1.reversed()));

          // face 2
          Face f2({cell_nodes(0),
                   cell_nodes(4),
                   cell_nodes(7),
                   cell_nodes(3)});
          face_cells_map[f2].push_back(std::make_tuple(j, 2, f2.reversed()));

          // face 3
          Face f3({cell_nodes(1),
                   cell_nodes(2),
                   cell_nodes(6),
                   cell_nodes(5)});
          face_cells_map[f3].push_back(std::make_tuple(j, 3, f3.reversed()));

          // face 4
          Face f4({cell_nodes(0),
                   cell_nodes(1),
                   cell_nodes(5),
                   cell_nodes(4)});
          face_cells_map[f4].push_back(std::make_tuple(j, 4, f4.reversed()));

          // face 5
          Face f5({cell_nodes(3),
                   cell_nodes(7),
                   cell_nodes(6),
                   cell_nodes(2)});
          face_cells_map[f5].push_back(std::make_tuple(j, 5, f5.reversed()));

          cell_nb_faces[j] = 6;
          break;
        }
        default: {
          std::cerr << "unexpected cell type!\n";
          std::exit(0);
        }
      }
    }

    {
      std::vector<std::vector<unsigned int>> face_to_node_vector(face_cells_map.size());
      int l=0;
      for (const auto& face_info : face_cells_map) {
        const Face& face = face_info.first;
        face_to_node_vector[l] = face.nodeIdList();
        ++l;
      }
      m_face_to_node_matrix
          = Kokkos::create_staticcrsgraph<ConnectivityMatrix>("face_to_node_matrix", face_to_node_vector);
    }

    {
      int l=0;
      for (const auto& face_cells_vector : face_cells_map) {
        const Face& face = face_cells_vector.first;
        m_face_number_map[face] = l;
        ++l;
      }
    }

    {
      std::vector<std::vector<unsigned int>> face_to_cell_vector(face_cells_map.size());
      size_t l=0;
      for (const auto& face_cells_vector : face_cells_map) {
        const auto& cells_info_vector = face_cells_vector.second;
        std::vector<unsigned int>& cells_vector = face_to_cell_vector[l];
        cells_vector.resize(cells_info_vector.size());
        for (size_t j=0; j<cells_info_vector.size(); ++j) {
          const auto& [cell_number, local_face_in_cell, reversed] = cells_info_vector[j];
          cells_vector[j] = cell_number;
        }
        ++l;
      }
      m_face_to_cell_matrix
          = Kokkos::create_staticcrsgraph<ConnectivityMatrix>("face_to_cell_matrix", face_to_cell_vector);
    }

    {
      std::vector<std::vector<unsigned short>> face_to_cell_local_face_vector(face_cells_map.size());
      size_t l=0;
      for (const auto& face_cells_vector : face_cells_map) {
        const auto& cells_info_vector = face_cells_vector.second;
        std::vector<unsigned short>& cells_vector = face_to_cell_local_face_vector[l];
        cells_vector.resize(cells_info_vector.size());
        for (size_t j=0; j<cells_info_vector.size(); ++j) {
          const auto& [cell_number, local_face_in_cell, reversed] = cells_info_vector[j];
          cells_vector[j] = local_face_in_cell;
        }
        ++l;
      }
      m_face_to_cell_local_face_matrix
          = Kokkos::create_staticcrsgraph<ConnectivityMatrixShort>("face_to_cell_local_face_matrix",
                                                                   face_to_cell_local_face_vector);
    }

#warning check that the number of cell per faces is <=2

    {
      std::vector<std::vector<unsigned int>> cell_to_face_vector(this->numberOfCells());
      for (size_t j=0; j<cell_to_face_vector.size(); ++j) {
        cell_to_face_vector[j].resize(cell_nb_faces[j]);
      }
      int l=0;
      for (const auto& face_cells_vector : face_cells_map) {
        const auto& cells_vector = face_cells_vector.second;
        for (unsigned short lj=0; lj<cells_vector.size(); ++lj) {
          const auto& [cell_number, cell_local_face, reversed] = cells_vector[lj];
          cell_to_face_vector[cell_number][cell_local_face] = l;
        }
        ++l;
      }
      m_cell_to_face_matrix
          = Kokkos::create_staticcrsgraph<ConnectivityMatrix>("cell_to_face_matrix", cell_to_face_vector);
    }

    {
      std::vector<std::vector<unsigned short>> cell_to_face_is_reversed_vector(this->numberOfCells());
      for (size_t j=0; j<cell_to_face_is_reversed_vector.size(); ++j) {
        cell_to_face_is_reversed_vector[j].resize(cell_nb_faces[j]);
      }
      int l=0;
      for (const auto& face_cells_vector : face_cells_map) {
        const auto& cells_vector = face_cells_vector.second;
        for (unsigned short lj=0; lj<cells_vector.size(); ++lj) {
          const auto& [cell_number, cell_local_face, reversed] = cells_vector[lj];
          cell_to_face_is_reversed_vector[cell_number][cell_local_face] = reversed;
        }
        ++l;
      }

      m_cell_to_face_is_reversed_matrix
          = Kokkos::create_staticcrsgraph<ConnectivityMatrixShort>("cell_to_face_is_reversed_matrix",
                                                                   cell_to_face_is_reversed_vector);
    }

    std::unordered_map<unsigned int, std::vector<unsigned int>> node_faces_map;
    for (size_t l=0; l<m_face_to_node_matrix.numRows(); ++l) {
      const auto& face_nodes = m_face_to_node_matrix.rowConst(l);
      for (size_t lr=0; lr<face_nodes.length; ++lr) {
        const unsigned int r = face_nodes(lr);
        node_faces_map[r].push_back(l);
      }
    }
    Kokkos::View<unsigned short*> node_nb_faces("node_nb_faces", this->numberOfNodes());
    size_t max_nb_face_per_node = 0;
    for (auto node_faces : node_faces_map) {
      max_nb_face_per_node = std::max(node_faces.second.size(), max_nb_face_per_node);
      node_nb_faces[node_faces.first] = node_faces.second.size();
    }
    m_node_nb_faces = node_nb_faces;

    Kokkos::View<unsigned int**> node_faces("node_faces", this->numberOfNodes(), max_nb_face_per_node);
    for (auto node_faces_vector : node_faces_map) {
      const unsigned int r = node_faces_vector.first;
      const std::vector<unsigned int>&  faces_vector = node_faces_vector.second;
      for (size_t l=0; l < faces_vector.size(); ++l) {
        node_faces(r, l) = faces_vector[l];
      }
    }
    m_node_faces = node_faces;
  }

 public:
  void addRefFaceList(const RefFaceList& ref_face_list)
  {
    m_ref_face_list.push_back(ref_face_list);
  }

  size_t numberOfRefFaceList() const
  {
    return m_ref_face_list.size();
  }

  const RefFaceList& refFaceList(const size_t& i) const
  {
    return m_ref_face_list[i];
  }

  void addRefNodeList(const RefNodeList& ref_node_list)
  {
    m_ref_node_list.push_back(ref_node_list);
  }

  size_t numberOfRefNodeList() const
  {
    return m_ref_node_list.size();
  }

  const RefNodeList& refNodeList(const size_t& i) const
  {
    return m_ref_node_list[i];
  }

  KOKKOS_INLINE_FUNCTION
  size_t numberOfNodes() const
  {
    return m_node_to_cell_matrix.numRows();
  }

  KOKKOS_INLINE_FUNCTION
  size_t numberOfFaces() const
  {
    return m_face_to_cell_matrix.numRows();
  }

  KOKKOS_INLINE_FUNCTION
  size_t numberOfCells() const
  {
    return m_cell_to_node_matrix.numRows();
  }

  const Kokkos::View<const double*> invCellNbNodes() const
  {
    return m_inv_cell_nb_nodes;
  }

  unsigned int getFaceNumber(const std::vector<unsigned int>& face_nodes) const
  {
    const Face face(face_nodes);
    auto i_face = m_face_number_map.find(face);
    if (i_face == m_face_number_map.end()) {
      std::cerr << "Face " << face << " not found!\n";
      std::exit(0);
    }
    return i_face->second;
  }

  Connectivity3D(const Connectivity3D&) = delete;

  Connectivity3D(const std::vector<std::vector<unsigned int>>& cell_by_node_vector)
  {
    m_cell_to_node_matrix
        = Kokkos::create_staticcrsgraph<ConnectivityMatrix>("cell_to_node_matrix",
                                                            cell_by_node_vector);

    Assert(this->numberOfCells()>0);

    {
      Kokkos::View<double*> inv_cell_nb_nodes("inv_cell_nb_nodes", this->numberOfCells());
      Kokkos::parallel_for(this->numberOfCells(), KOKKOS_LAMBDA(const int& j){
          const auto& cell_nodes = m_cell_to_node_matrix.rowConst(j);
          inv_cell_nb_nodes[j] = 1./cell_nodes.length;
        });
      m_inv_cell_nb_nodes = inv_cell_nb_nodes;
    }

    {
      std::vector<std::vector<unsigned int>> node_id_per_cell_vector(this->numberOfCells());
      unsigned int id=0;
      for (unsigned int j=0; j<this->numberOfCells(); ++j) {
        const auto& cell_to_node = m_cell_to_node_matrix.rowConst(j);
        auto& node_id_per_cell = node_id_per_cell_vector[j];
        node_id_per_cell.resize(cell_to_node.length);
        for (size_t r=0; r<cell_to_node.length; ++r) {
          node_id_per_cell[r] = id++;
        }
      }
      m_node_id_per_cell_matrix
          = Kokkos::create_staticcrsgraph<ConnectivityMatrix>("node_id_per_cell_matrix",
                                                              node_id_per_cell_vector);
    }

    ConnectivityUtils utils;
    utils.computeNodeCellConnectivity(m_cell_to_node_matrix,
                                      m_node_to_cell_matrix,
                                      m_node_to_cell_local_node_matrix);

    this->_computeFaceCellConnectivities();
  }

  ~Connectivity3D()
  {
    ;
  }
};

#endif // CONNECTIVITY_3D_HPP