#include <mesh/StencilBuilder.hpp>

#include <mesh/Connectivity.hpp>
#include <mesh/ItemArray.hpp>
#include <utils/GlobalVariableManager.hpp>
#include <utils/Messenger.hpp>

#include <set>

// Computes the row offsets (CSR "row map") of the one-layer cell-to-cell
// stencil.
//
// For every cell, the difference between two consecutive offsets is the
// number of distinct cells that share at least one node with it, the cell
// itself being excluded. Cells that are not owned get an empty row.
//
// The returned array has numberOfCells()+1 entries and starts at 0.
template <typename ConnectivityType>
Array<const uint32_t>
StencilBuilder::_getRowMap(const ConnectivityType& connectivity) const
{
  auto is_owned      = connectivity.cellIsOwned();
  auto cell_nodes_of = connectivity.cellToNodeMatrix();
  auto node_cells_of = connectivity.nodeToCellMatrix();

  Array<uint32_t> offsets{connectivity.numberOfCells() + 1};
  offsets[0] = 0;

  // Scratch buffer, reused from one cell to the next
  std::vector<CellId> around;

  for (CellId cell_id = 0; cell_id < connectivity.numberOfCells(); ++cell_id) {
    around.clear();

    // Ghost cells keep an empty buffer, hence an empty row
    if (is_owned[cell_id]) {
      auto nodes = cell_nodes_of[cell_id];
      for (size_t i_node = 0; i_node < nodes.size(); ++i_node) {
        const NodeId node_id = nodes[i_node];
        auto cells           = node_cells_of[node_id];
        for (size_t i_cell = 0; i_cell < cells.size(); ++i_cell) {
          const CellId other_id = cells[i_cell];
          // The cell itself does not belong to its own stencil
          if (other_id != cell_id) {
            around.push_back(other_id);
          }
        }
      }

      // A neighbor reached through several nodes must be counted once
      std::sort(around.begin(), around.end());
      const auto last_distinct = std::unique(around.begin(), around.end());
      around.erase(last_distinct, around.end());
    }

    offsets[cell_id + 1] = offsets[cell_id] + around.size();
  }

  return offsets;
}

// Fills the column indices of the one-layer cell-to-cell stencil whose row
// offsets are given by row_map.
//
// For each owned cell, the cells sharing at least one node with it (itself
// excluded) are stored once each, ordered by increasing cell number. Each
// row is built by sorted insertion into the slots
// [row_map[cell], row_map[cell + 1]). Rows of cells that are not owned are
// left untouched.
template <typename ConnectivityType>
Array<const uint32_t>
StencilBuilder::_getColumnIndices(const ConnectivityType& connectivity, const Array<const uint32_t>& row_map) const
{
  auto cell_number   = connectivity.cellNumber();
  auto cell_is_owned = connectivity.cellIsOwned();

  auto cell_to_node_matrix = connectivity.cellToNodeMatrix();
  auto node_to_cell_matrix = connectivity.nodeToCellMatrix();

  // row_end[i] is one past the last slot already filled in row i
  Array<uint32_t> row_end(row_map.size() - 1);
  parallel_for(
    row_end.size(), PUGS_LAMBDA(size_t i_row) { row_end[i_row] = row_map[i_row]; });

  Array<uint32_t> column_indices(row_map[row_map.size() - 1]);
  column_indices.fill(std::numeric_limits<uint32_t>::max());

  for (CellId cell_id = 0; cell_id < connectivity.numberOfCells(); ++cell_id) {
    // No stencil is built for ghost cells
    if (not cell_is_owned[cell_id]) {
      continue;
    }

    const size_t row_begin = row_map[cell_id];

    auto cell_nodes = cell_to_node_matrix[cell_id];
    for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
      const NodeId node_id = cell_nodes[i_node];
      auto node_cells      = node_to_cell_matrix[node_id];

      for (size_t i_node_cell = 0; i_node_cell < node_cells.size(); ++i_node_cell) {
        const CellId candidate_id = node_cells[i_node_cell];
        if (candidate_id != cell_id) {
          // Is the candidate already stored in the row?
          bool already_listed = false;
          for (size_t i_slot = row_begin; (i_slot < row_end[cell_id]) and (not already_listed); ++i_slot) {
            already_listed = (column_indices[i_slot] == candidate_id);
          }

          if (not already_listed) {
            const int candidate_number = cell_number[candidate_id];

            // First slot whose cell number is not smaller than the candidate's
            size_t i_insert = row_begin;
            while ((i_insert < row_end[cell_id]) and
                   (candidate_number > cell_number[CellId(column_indices[i_insert])])) {
              ++i_insert;
            }

            // Shift the tail of the row one slot to the right to make room
            for (size_t i_shift = row_end[cell_id]; i_shift > i_insert; --i_shift) {
              column_indices[i_shift] = column_indices[i_shift - 1];
            }

            column_indices[i_insert] = candidate_id;
            ++row_end[cell_id];
          }
        }
      }
    }
  }

  return column_indices;
}

// Builds the cell-to-cell stencil of a connectivity.
//
// Parameters:
//  - connectivity: connectivity whose cell/node adjacency defines the
//    neighborhood (two cells are neighbors when they share a node)
//  - number_of_layers: number of neighbor layers (at most 2)
//  - symmetry_boundary_descriptor_list: boundaries for which an additional
//    stencil is built, one per boundary
//
// Returns the primal stencil (row map + column indices) together with the
// list of per-boundary stencils (empty when no symmetry boundary is given).
// In every stencil, rows of cells that are not owned are empty and the
// cells of a row are ordered by increasing cell number.
//
// Throws:
//  - NormalError if running in parallel with fewer ghost layers than
//    number_of_layers, or if a symmetry boundary cannot be found
//  - NotImplementedError if number_of_layers > 2, or if symmetry
//    boundaries are given for a 1D connectivity
//  - UnexpectedError if the two-layer stencil sizes are inconsistent
template <typename ConnectivityType>
CellToCellStencilArray
StencilBuilder::_buildC2C(const ConnectivityType& connectivity,
                          size_t number_of_layers,
                          const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const
{
  if ((parallel::size() > 1) and (number_of_layers > GlobalVariableManager::instance().getNumberOfGhostLayers())) {
    std::ostringstream error_msg;
    error_msg << "Stencil builder requires " << rang::fgB::yellow << number_of_layers << rang::fg::reset
              << " layers while parallel number of ghost layer is "
              << GlobalVariableManager::instance().getNumberOfGhostLayers() << ".\n";
    error_msg << "Increase the number of ghost layers (using the '--number-of-ghost-layers' option).";
    throw NormalError(error_msg.str());
  }

  if (number_of_layers > 2) {
    throw NotImplementedError("number of layers too large");
  }

  if (symmetry_boundary_descriptor_list.size() == 0) {
    if (number_of_layers == 1) {
      Array<const uint32_t> row_map        = this->_getRowMap(connectivity);
      Array<const uint32_t> column_indices = this->_getColumnIndices(connectivity, row_map);

      return {ConnectivityMatrix{row_map, column_indices}, {}};
    } else {
      // NOTE(review): number_of_layers == 0 also ends up here and gets the
      // two-layer stencil -- confirm whether this is intended.
      auto cell_to_node_matrix = connectivity.cellToNodeMatrix();
      auto node_to_cell_matrix = connectivity.nodeToCellMatrix();

      auto cell_is_owned = connectivity.cellIsOwned();
      auto cell_number   = connectivity.cellNumber();

      Array<uint32_t> row_map{connectivity.numberOfCells() + 1};
      row_map[0] = 0;

      std::vector<uint32_t> column_indices_vector;

      // Orders the cells of a row by increasing cell number
      auto by_cell_number = [cell_number](CellId cell_0, CellId cell_1) {
        return cell_number[cell_0] < cell_number[cell_1];
      };

      for (CellId cell_id = 0; cell_id < connectivity.numberOfCells(); ++cell_id) {
        // The stencil is not built for ghost cells: their row stays empty,
        // but its offset must still be written to keep the row map valid
        size_t stencil_size = 0;

        if (cell_is_owned[cell_id]) {
          std::set<CellId, decltype(by_cell_number)> cell_set(by_cell_number);

          auto layer_1_node_list = cell_to_node_matrix[cell_id];
          for (size_t i_node_1 = 0; i_node_1 < layer_1_node_list.size(); ++i_node_1) {
            const NodeId layer_1_node_id = layer_1_node_list[i_node_1];

            auto layer_1_cell_list = node_to_cell_matrix[layer_1_node_id];
            for (size_t i_cell_1 = 0; i_cell_1 < layer_1_cell_list.size(); ++i_cell_1) {
              const CellId cell_1_id = layer_1_cell_list[i_cell_1];

              auto layer_2_node_list = cell_to_node_matrix[cell_1_id];
              for (size_t i_node_2 = 0; i_node_2 < layer_2_node_list.size(); ++i_node_2) {
                const NodeId layer_2_node_id = layer_2_node_list[i_node_2];

                auto layer_2_cell_list = node_to_cell_matrix[layer_2_node_id];
                for (size_t i_cell_2 = 0; i_cell_2 < layer_2_cell_list.size(); ++i_cell_2) {
                  const CellId cell_2_id = layer_2_cell_list[i_cell_2];

                  // The cell itself is not part of its stencil
                  if (cell_2_id != cell_id) {
                    cell_set.insert(cell_2_id);
                  }
                }
              }
            }
          }

          for (auto stencil_cell_id : cell_set) {
            column_indices_vector.push_back(stencil_cell_id);
          }
          stencil_size = cell_set.size();
        }

        row_map[cell_id + 1] = row_map[cell_id] + stencil_size;
      }

      if (row_map[row_map.size() - 1] != column_indices_vector.size()) {
        throw UnexpectedError("incorrect stencil size");
      }

      ConnectivityMatrix primal_stencil{row_map, convert_to_array(column_indices_vector)};

      return {primal_stencil, {}};
    }
  } else {
    // NOTE(review): number_of_layers is not used in this branch; the
    // stencils built here only contain the node-connected neighbors (one
    // layer) -- confirm that callers do not expect two layers here.
    if constexpr (ConnectivityType::Dimension > 1) {
      // symmetry_node_list[node_id][i] is true when the node lies on a face
      // of the i-th symmetry boundary
      NodeArray<bool> symmetry_node_list(connectivity, symmetry_boundary_descriptor_list.size());
      symmetry_node_list.fill(0);

      auto face_to_node_matrix = connectivity.faceToNodeMatrix();
      auto cell_to_node_matrix = connectivity.cellToNodeMatrix();
      auto node_to_cell_matrix = connectivity.nodeToCellMatrix();

      {
        size_t i_symmetry_boundary = 0;
        for (const auto& p_boundary_descriptor : symmetry_boundary_descriptor_list) {
          const IBoundaryDescriptor& boundary_descriptor = *p_boundary_descriptor;

          bool found = false;
          for (size_t i_ref_face_list = 0;
               i_ref_face_list < connectivity.template numberOfRefItemList<ItemType::face>(); ++i_ref_face_list) {
            const auto& ref_face_list = connectivity.template refItemList<ItemType::face>(i_ref_face_list);
            if (ref_face_list.refId() == boundary_descriptor) {
              found = true;

              auto face_list = ref_face_list.list();
              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id = face_list[i_face];
                auto node_list       = face_to_node_matrix[face_id];
                for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                  const NodeId node_id = node_list[i_node];

                  symmetry_node_list[node_id][i_symmetry_boundary] = true;
                }
              }
              break;
            }
          }

          if (not found) {
            std::ostringstream error_msg;
            error_msg << "cannot find boundary '" << rang::fgB::yellow << boundary_descriptor << rang::fg::reset
                      << '\'';
            throw NormalError(error_msg.str());
          }
          ++i_symmetry_boundary;
        }
      }

      auto cell_is_owned = connectivity.cellIsOwned();
      auto cell_number   = connectivity.cellNumber();

      // Orders the cells of a row by increasing cell number
      auto by_cell_number = [&cell_number](const CellId& cell0_id, const CellId& cell1_id) {
        return cell_number[cell0_id] < cell_number[cell1_id];
      };

      Array<uint32_t> row_map{connectivity.numberOfCells() + 1};
      row_map[0] = 0;
      std::vector<Array<uint32_t>> symmetry_row_map_list(symmetry_boundary_descriptor_list.size());
      for (auto&& symmetry_row_map : symmetry_row_map_list) {
        symmetry_row_map    = Array<uint32_t>{connectivity.numberOfCells() + 1};
        symmetry_row_map[0] = 0;
      }

      std::vector<uint32_t> column_indices_vector;
      std::vector<std::vector<uint32_t>> symmetry_column_indices_vector(symmetry_boundary_descriptor_list.size());

      // Per-boundary row sizes of the current cell
      std::vector<size_t> symmetry_stencil_size(symmetry_boundary_descriptor_list.size());

      for (CellId cell_id = 0; cell_id < connectivity.numberOfCells(); ++cell_id) {
        // Rows of ghost cells stay empty in every stencil
        size_t stencil_size = 0;
        symmetry_stencil_size.assign(symmetry_stencil_size.size(), 0);

        if (cell_is_owned[cell_id]) {
          auto cell_node_list = cell_to_node_matrix[cell_id];

          // Primal stencil: cells sharing a node with the cell, itself excluded
          {
            std::set<CellId> cell_set;
            for (size_t i_cell_node = 0; i_cell_node < cell_node_list.size(); ++i_cell_node) {
              const NodeId cell_node_id = cell_node_list[i_cell_node];
              auto node_cell_list       = node_to_cell_matrix[cell_node_id];
              for (size_t i_node_cell = 0; i_node_cell < node_cell_list.size(); ++i_node_cell) {
                const CellId node_cell_id = node_cell_list[i_node_cell];
                if (cell_id != node_cell_id) {
                  cell_set.insert(node_cell_id);
                }
              }
            }

            std::vector<CellId> cell_vector(cell_set.begin(), cell_set.end());
            std::sort(cell_vector.begin(), cell_vector.end(), by_cell_number);

            for (auto&& vector_cell_id : cell_vector) {
              column_indices_vector.push_back(vector_cell_id);
            }
            stencil_size = cell_set.size();
          }

          // Symmetry stencils: cells sharing with the cell a node that lies
          // on the boundary. The cell itself is kept here (presumably since
          // its own mirror image is a neighbor -- TODO confirm).
          for (size_t i = 0; i < symmetry_boundary_descriptor_list.size(); ++i) {
            std::set<CellId> symmetry_cell_set;
            for (size_t i_cell_node = 0; i_cell_node < cell_node_list.size(); ++i_cell_node) {
              const NodeId cell_node_id = cell_node_list[i_cell_node];
              if (symmetry_node_list[cell_node_id][i]) {
                auto node_cell_list = node_to_cell_matrix[cell_node_id];
                for (size_t i_node_cell = 0; i_node_cell < node_cell_list.size(); ++i_node_cell) {
                  const CellId node_cell_id = node_cell_list[i_node_cell];
                  symmetry_cell_set.insert(node_cell_id);
                }
              }
            }
            symmetry_stencil_size[i] = symmetry_cell_set.size();

            std::vector<CellId> cell_vector(symmetry_cell_set.begin(), symmetry_cell_set.end());
            std::sort(cell_vector.begin(), cell_vector.end(), by_cell_number);

            for (auto&& vector_cell_id : cell_vector) {
              symmetry_column_indices_vector[i].push_back(vector_cell_id);
            }
          }
        }

        row_map[cell_id + 1] = row_map[cell_id] + stencil_size;

        for (size_t i = 0; i < symmetry_row_map_list.size(); ++i) {
          symmetry_row_map_list[i][cell_id + 1] = symmetry_row_map_list[i][cell_id] + symmetry_stencil_size[i];
        }
      }
      ConnectivityMatrix primal_stencil{row_map, convert_to_array(column_indices_vector)};

      CellToCellStencilArray::BoundaryDescriptorStencilArrayList symmetry_boundary_stencil_list;
      {
        size_t i = 0;
        for (auto&& p_boundary_descriptor : symmetry_boundary_descriptor_list) {
          symmetry_boundary_stencil_list.emplace_back(
            CellToCellStencilArray::
              BoundaryDescriptorStencilArray{p_boundary_descriptor,
                                             ConnectivityMatrix{symmetry_row_map_list[i],
                                                                convert_to_array(symmetry_column_indices_vector[i])}});
          ++i;
        }
      }

      return {{primal_stencil}, {symmetry_boundary_stencil_list}};

    } else {
      throw NotImplementedError("Only implemented in 2D/3D");
    }
  }
}

// Builds the cell-to-cell stencil of a connectivity of any supported
// dimension.
//
// The connectivity is downcast to the concrete Connectivity<Dimension>
// matching connectivity.dimension(), then the build is delegated to
// _buildC2C with the number of layers of the stencil descriptor.
//
// Throws UnexpectedError if the dimension is not 1, 2 or 3.
CellToCellStencilArray
StencilBuilder::buildC2C(const IConnectivity& connectivity,
                         const StencilDescriptor& stencil_descriptor,
                         const BoundaryDescriptorList& symmetry_boundary_descriptor_list) const
{
  // Common tail of every valid case, once the concrete type is known
  const auto build_for = [&](const auto& typed_connectivity) {
    return StencilBuilder::_buildC2C(typed_connectivity, stencil_descriptor.numberOfLayers(),
                                     symmetry_boundary_descriptor_list);
  };

  switch (connectivity.dimension()) {
  case 1:
    return build_for(dynamic_cast<const Connectivity<1>&>(connectivity));
  case 2:
    return build_for(dynamic_cast<const Connectivity<2>&>(connectivity));
  case 3:
    return build_for(dynamic_cast<const Connectivity<3>&>(connectivity));
  default:
    throw UnexpectedError("invalid connectivity dimension");
  }
}

// Cell-to-face stencil construction.
//
// Not available yet: all arguments are ignored and the call always throws
// NotImplementedError.
CellToFaceStencilArray
StencilBuilder::buildC2F(const IConnectivity&, const StencilDescriptor&, const BoundaryDescriptorList&) const
{
  throw NotImplementedError("cell to face stencil");
}

// Node-to-cell stencil construction.
//
// Not available yet: all arguments are ignored and the call always throws
// NotImplementedError.
NodeToCellStencilArray
StencilBuilder::buildN2C(const IConnectivity&, const StencilDescriptor&, const BoundaryDescriptorList&) const
{
  throw NotImplementedError("node to cell stencil");
}