#include <mesh/MeshBuilderBase.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/ConnectivityDescriptor.hpp>
#include <mesh/ConnectivityDispatcher.hpp>
#include <mesh/ConnectivityDispatcherVariant.hpp>
#include <mesh/ItemId.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/MeshVariant.hpp>
#include <utils/PugsAssert.hpp>
#include <utils/PugsMacros.hpp>

#include <algorithm>
#include <iterator>
#include <memory>
#include <sstream>
#include <vector>

template <size_t Dimension>
void
MeshBuilderBase::_dispatch()
{
  if (parallel::size() == 1) {
    return;
  }

  using ConnectivityType = Connectivity<Dimension>;
  using Rd               = TinyVector<Dimension>;
  using MeshType         = Mesh<Dimension>;

  if (not m_mesh) {
    ConnectivityDescriptor descriptor;
    std::shared_ptr connectivity = ConnectivityType::build(descriptor);
    NodeValue<Rd> xr;
    m_mesh = std::make_shared<MeshVariant>(std::make_shared<const MeshType>(connectivity, xr));
  }

  const MeshType& mesh = *(m_mesh->get<const MeshType>());

  auto p_dispatcher = std::make_shared<const ConnectivityDispatcher<Dimension>>(mesh.connectivity());

  m_connectivity_dispatcher = std::make_shared<ConnectivityDispatcherVariant>(p_dispatcher);

  std::shared_ptr dispatched_connectivity = p_dispatcher->dispatchedConnectivity();
  NodeValue<Rd> dispatched_xr             = p_dispatcher->dispatch(mesh.xr());

  m_mesh = std::make_shared<MeshVariant>(std::make_shared<const MeshType>(dispatched_connectivity, dispatched_xr));
}

// Explicit instantiations of _dispatch for dimensions 1, 2 and 3: the
// template is defined in this source file, so these are the instantiations
// emitted by this translation unit.
template void MeshBuilderBase::_dispatch<1>();
template void MeshBuilderBase::_dispatch<2>();
template void MeshBuilderBase::_dispatch<3>();

template <size_t Dimension>
void
MeshBuilderBase::_checkMesh() const
{
  using MeshType         = Mesh<Dimension>;
  using ConnectivityType = typename MeshType::Connectivity;

  if (not m_mesh) {
    throw UnexpectedError("mesh is not built yet");
  }

  const MeshType& mesh = *(m_mesh->get<const MeshType>());

  const ConnectivityType& connectivity = mesh.connectivity();

  if constexpr (Dimension > 2) {   // check for duplicated edges
    auto edge_to_node_matrix = connectivity.edgeToNodeMatrix();

    std::vector<std::vector<EdgeId>> node_edges(mesh.numberOfNodes());
    for (EdgeId edge_id = 0; edge_id < mesh.numberOfEdges(); ++edge_id) {
      auto node_list = edge_to_node_matrix[edge_id];
      for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
        node_edges[node_list[i_node]].push_back(edge_id);
      }
    }

    for (auto&& edge_list : node_edges) {
      std::sort(edge_list.begin(), edge_list.end());
    }

    for (EdgeId edge_id = 0; edge_id < mesh.numberOfEdges(); ++edge_id) {
      auto node_list                   = edge_to_node_matrix[edge_id];
      std::vector<EdgeId> intersection = node_edges[node_list[0]];
      for (size_t i_node = 1; i_node < node_list.size(); ++i_node) {
        std::vector<EdgeId> local_intersection;
        std::set_intersection(intersection.begin(), intersection.end(),   //
                              node_edges[node_list[i_node]].begin(), node_edges[node_list[i_node]].end(),
                              std::back_inserter(local_intersection));
        std::swap(local_intersection, intersection);
        if (intersection.size() < 2) {
          break;
        }
      }

      if (intersection.size() > 1) {
        std::ostringstream error_msg;
        error_msg << "invalid mesh.\n\tFollowing edges\n";
        for (EdgeId duplicated_edge_id : intersection) {
          error_msg << "\t - id=" << duplicated_edge_id << " number=" << connectivity.edgeNumber()[duplicated_edge_id]
                    << '\n';
        }
        error_msg << "\tare duplicated";
        throw NormalError(error_msg.str());
      }
    }
  }

  if constexpr (Dimension > 1) {   // check for duplicated faces
    auto face_to_node_matrix = connectivity.faceToNodeMatrix();

    std::vector<std::vector<FaceId>> node_faces(mesh.numberOfNodes());
    for (FaceId face_id = 0; face_id < mesh.numberOfFaces(); ++face_id) {
      auto node_list = face_to_node_matrix[face_id];
      for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
        node_faces[node_list[i_node]].push_back(face_id);
      }
    }

    for (auto&& face_list : node_faces) {
      std::sort(face_list.begin(), face_list.end());
    }

    for (FaceId face_id = 0; face_id < mesh.numberOfFaces(); ++face_id) {
      auto node_list                   = face_to_node_matrix[face_id];
      std::vector<FaceId> intersection = node_faces[node_list[0]];
      for (size_t i_node = 1; i_node < node_list.size(); ++i_node) {
        std::vector<FaceId> local_intersection;
        std::set_intersection(intersection.begin(), intersection.end(),   //
                              node_faces[node_list[i_node]].begin(), node_faces[node_list[i_node]].end(),
                              std::back_inserter(local_intersection));
        std::swap(local_intersection, intersection);
        if (intersection.size() < 2) {
          break;
        }
      }

      if (intersection.size() > 1) {
        std::ostringstream error_msg;
        error_msg << "invalid mesh.\n\tFollowing faces\n";
        for (FaceId intersection_face_id : intersection) {
          error_msg << "\t - id=" << intersection_face_id
                    << " number=" << connectivity.faceNumber()[intersection_face_id] << '\n';
          error_msg << "\t   nodes:";
          for (size_t i = 0; i < face_to_node_matrix[intersection_face_id].size(); ++i) {
            error_msg << ' ' << face_to_node_matrix[intersection_face_id][i];
          }
          error_msg << '\n';
        }
        error_msg << "\tare duplicated";
        throw NormalError(error_msg.str());
      }
    }
  }

  auto cell_to_node_matrix = connectivity.cellToNodeMatrix();

  {   // check for duplicated cells
    std::vector<std::vector<CellId>> node_cells(mesh.numberOfNodes());
    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
      auto node_list = cell_to_node_matrix[cell_id];
      for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
        node_cells[node_list[i_node]].push_back(cell_id);
      }
    }

    for (auto&& cell_list : node_cells) {
      std::sort(cell_list.begin(), cell_list.end());
    }

    for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
      auto node_list                   = cell_to_node_matrix[cell_id];
      std::vector<CellId> intersection = node_cells[node_list[0]];
      for (size_t i_node = 1; i_node < node_list.size(); ++i_node) {
        std::vector<CellId> local_intersection;
        std::set_intersection(intersection.begin(), intersection.end(),   //
                              node_cells[node_list[i_node]].begin(), node_cells[node_list[i_node]].end(),
                              std::back_inserter(local_intersection));
        std::swap(local_intersection, intersection);
        if (intersection.size() < 2) {
          break;
        }
      }

      if (intersection.size() > 1) {
        std::ostringstream error_msg;
        error_msg << "invalid mesh.\n\tFollowing cells\n";
        for (CellId duplicated_cell_id : intersection) {
          error_msg << "\t - id=" << duplicated_cell_id << " number=" << connectivity.cellNumber()[duplicated_cell_id]
                    << '\n';
        }
        error_msg << "\tare duplicated";
        throw NormalError(error_msg.str());
      }
    }
  }

  const auto& Cjr = MeshDataManager::instance().getMeshData(mesh).Cjr();
  const auto& xr  = mesh.xr();

  for (CellId cell_id = 0; cell_id < mesh.numberOfCells(); ++cell_id) {
    double cell_volume = 0;
    auto cell_nodes    = cell_to_node_matrix[cell_id];
    for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
      cell_volume += dot(Cjr(cell_id, i_node), xr[cell_nodes[i_node]]);
    }

    if (cell_volume <= 0) {
      std::ostringstream error_msg;
      error_msg << "invalid mesh.\n\tThe following cell\n";
      error_msg << "\t - id=" << cell_id << " number=" << connectivity.cellNumber()[cell_id] << '\n';
      error_msg << "\thas non-positive volume: " << cell_volume / Dimension;
      throw NormalError(error_msg.str());
    }
  }
}

// Explicit instantiations of _checkMesh for dimensions 1, 2 and 3: the
// template is defined in this source file, so these are the instantiations
// emitted by this translation unit.
template void MeshBuilderBase::_checkMesh<1>() const;
template void MeshBuilderBase::_checkMesh<2>() const;
template void MeshBuilderBase::_checkMesh<3>() const;