#ifndef MESH_RANDOMIZER_HPP
#define MESH_RANDOMIZER_HPP

#include <algebra/TinyMatrix.hpp>
#include <algebra/TinyVector.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <scheme/AxisBoundaryConditionDescriptor.hpp>
#include <scheme/FixedBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>
#include <utils/RandomEngine.hpp>

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <random>
#include <sstream>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

/**
 * Builds randomly perturbed copies of a given mesh.
 *
 * Every node is displaced by a random vector whose components are drawn in
 * [-0.45, 0.45) and scaled by the distance from the node to its closest
 * neighbour (among the nodes sharing a cell with it). Boundary conditions
 * (axis, symmetry, fixed) then constrain the displacement of boundary nodes.
 */
template <size_t Dimension>
class MeshRandomizer
{
 private:
  using Rd               = TinyVector<Dimension>;
  using Rdxd             = TinyMatrix<Dimension>;
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;

  // Mesh to randomize. It is stored by reference, so it must outlive the randomizer.
  const MeshType& m_given_mesh;

  class AxisBoundaryCondition;
  class FixedBoundaryCondition;
  class SymmetryBoundaryCondition;

  using BoundaryCondition = std::variant<AxisBoundaryCondition, FixedBoundaryCondition, SymmetryBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;
  BoundaryConditionList m_boundary_condition_list;

  /**
   * Builds the node boundary conditions described by bc_descriptor_list.
   *
   * For each descriptor, every reference item list of the mesh whose RefId
   * matches the descriptor's boundary yields one condition:
   *  - axis: a line node boundary built from edges (faces in 2D). In 1D an
   *    axis reduces to a point and is handled as a fixed condition;
   *  - symmetry: a flat node boundary built from faces (nodes in 1D);
   *  - fixed: a node boundary built from faces (nodes in 1D).
   *
   * It is static because it is evaluated in the constructor's member
   * initializer list and must not depend on any member state.
   *
   * @param mesh mesh whose reference item lists are searched
   * @param bc_descriptor_list user boundary condition descriptors
   * @return the list of boundary conditions to apply to the displacement
   * @throws NormalError if a descriptor type is not supported by the randomizer
   */
  static BoundaryConditionList
  _getBCList(const MeshType& mesh,
             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
  {
    BoundaryConditionList bc_list;

    // Items describing boundary lines: edges in 3D, faces in 2D, nodes in 1D
    constexpr ItemType EdgeType = [] {
      if constexpr (Dimension == 3) {
        return ItemType::edge;
      } else if constexpr (Dimension == 2) {
        return ItemType::face;
      } else {
        return ItemType::node;
      }
    }();

    // Items describing boundary faces: faces in 2D and 3D, nodes in 1D
    constexpr ItemType FaceType = [] {
      if constexpr (Dimension > 1) {
        return ItemType::face;
      } else {
        return ItemType::node;
      }
    }();

    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::axis: {
        const AxisBoundaryConditionDescriptor& axis_bc_descriptor =
          dynamic_cast<const AxisBoundaryConditionDescriptor&>(*bc_descriptor);
        for (size_t i_ref_edge_list = 0; i_ref_edge_list < mesh.connectivity().template numberOfRefItemList<EdgeType>();
             ++i_ref_edge_list) {
          const auto& ref_edge_list = mesh.connectivity().template refItemList<EdgeType>(i_ref_edge_list);
          const RefId& ref          = ref_edge_list.refId();
          if (ref == axis_bc_descriptor.boundaryDescriptor()) {
            if constexpr (Dimension == 1) {
              // An axis in 1D is a single point: its node cannot move
              bc_list.emplace_back(FixedBoundaryCondition{MeshNodeBoundary<Dimension>{mesh, ref_edge_list}});
            } else {
              bc_list.emplace_back(AxisBoundaryCondition{MeshLineNodeBoundary<Dimension>(mesh, ref_edge_list)});
            }
          }
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::symmetry: {
        const SymmetryBoundaryConditionDescriptor& sym_bc_descriptor =
          dynamic_cast<const SymmetryBoundaryConditionDescriptor&>(*bc_descriptor);
        for (size_t i_ref_face_list = 0; i_ref_face_list < mesh.connectivity().template numberOfRefItemList<FaceType>();
             ++i_ref_face_list) {
          const auto& ref_face_list = mesh.connectivity().template refItemList<FaceType>(i_ref_face_list);
          const RefId& ref          = ref_face_list.refId();
          if (ref == sym_bc_descriptor.boundaryDescriptor()) {
            bc_list.emplace_back(SymmetryBoundaryCondition{MeshFlatNodeBoundary<Dimension>(mesh, ref_face_list)});
          }
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::fixed: {
        const FixedBoundaryConditionDescriptor& fixed_bc_descriptor =
          dynamic_cast<const FixedBoundaryConditionDescriptor&>(*bc_descriptor);
        for (size_t i_ref_face_list = 0; i_ref_face_list < mesh.connectivity().template numberOfRefItemList<FaceType>();
             ++i_ref_face_list) {
          const auto& ref_face_list = mesh.connectivity().template refItemList<FaceType>(i_ref_face_list);
          const RefId& ref          = ref_face_list.refId();
          if (ref == fixed_bc_descriptor.boundaryDescriptor()) {
            bc_list.emplace_back(FixedBoundaryCondition{MeshNodeBoundary<Dimension>{mesh, ref_face_list}});
          }
        }
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for mesh randomizer";
        throw NormalError(error_msg.str());
      }
    }

    return bc_list;
  }

  /**
   * Constrains the displacement of boundary nodes in place.
   *
   *  - symmetry: the normal component is removed, so shift <- (I - n (x) n) shift;
   *  - axis: the shift is projected onto the axis direction, so shift <- (t (x) t) shift;
   *  - fixed: the shift is set to zero.
   *
   * Conditions are applied in list order. A node lying on several boundaries
   * therefore receives each constraint in turn.
   *
   * @param shift node displacements, modified in place
   */
  void
  _applyBC(NodeValue<Rd>& shift) const
  {
    for (auto&& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using BCType = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<BCType, SymmetryBoundaryCondition>) {
            const Rd& n = bc.outgoingNormal();

            const Rdxd I   = identity;
            const Rdxd nxn = tensorProduct(n, n);
            const Rdxd P   = I - nxn;

            const Array<const NodeId>& node_list = bc.nodeList();
            parallel_for(
              node_list.size(), PUGS_LAMBDA(const size_t i_node) {
                const NodeId node_id = node_list[i_node];

                shift[node_id] = P * shift[node_id];
              });

          } else if constexpr (std::is_same_v<BCType, AxisBoundaryCondition>) {
            const Rd& t = bc.direction();

            const Rdxd txt = tensorProduct(t, t);

            const Array<const NodeId>& node_list = bc.nodeList();
            parallel_for(
              node_list.size(), PUGS_LAMBDA(const size_t i_node) {
                const NodeId node_id = node_list[i_node];

                shift[node_id] = txt * shift[node_id];
              });

          } else if constexpr (std::is_same_v<BCType, FixedBoundaryCondition>) {
            const Array<const NodeId>& node_list = bc.nodeList();
            parallel_for(
              node_list.size(), PUGS_LAMBDA(const size_t i_node) {
                const NodeId node_id = node_list[i_node];
                shift[node_id]       = zero;
              });

          } else {
            throw UnexpectedError("invalid boundary condition type");
          }
        },
        boundary_condition);
    }
  }

  /**
   * Computes a random displacement for every node of the given mesh.
   *
   * Each node's shift has components max_delta * U(-0.45, 0.45). Here max_delta
   * is the distance from the node to its closest node among those sharing a
   * cell with it. Boundary conditions are then applied to the shift.
   *
   * Random numbers are drawn in increasing node number order, Dimension
   * values per node number. The engine is also advanced over node numbers
   * that are not present locally, up to the global maximum. Hence the draws
   * used for a node depend only on its number, and the random engines of all
   * processes stay synchronized.
   *
   * @return the node displacements (not the new coordinates)
   */
  NodeValue<Rd>
  _getDisplacement() const
  {
    const ConnectivityType& connectivity = m_given_mesh.connectivity();
    NodeValue<const Rd> given_xr         = m_given_mesh.xr();

    auto node_to_cell_matrix        = connectivity.nodeToCellMatrix();
    auto cell_to_node_matrix        = connectivity.cellToNodeMatrix();
    auto node_number_in_their_cells = connectivity.nodeLocalNumbersInTheirCells();

    // Distance from each node to its closest neighbour in the cells around it
    NodeValue<double> max_delta_xr{connectivity};
    parallel_for(
      connectivity.numberOfNodes(), PUGS_LAMBDA(const NodeId node_id) {
        const Rd& x0 = given_xr[node_id];

        const auto& node_cell_list = node_to_cell_matrix[node_id];
        double min_distance_2      = std::numeric_limits<double>::max();

        for (size_t i_cell = 0; i_cell < node_cell_list.size(); ++i_cell) {
          // local index of node_id inside this cell, used to skip the node itself
          const size_t i_cell_node = node_number_in_their_cells(node_id, i_cell);

          const CellId cell_id       = node_cell_list[i_cell];
          const auto& cell_node_list = cell_to_node_matrix[cell_id];

          for (size_t i_node = 0; i_node < cell_node_list.size(); ++i_node) {
            if (i_node != i_cell_node) {
              const NodeId cell_node_id = cell_node_list[i_node];
              const Rd delta            = x0 - given_xr[cell_node_id];
              min_distance_2            = std::min(min_distance_2, dot(delta, delta));
            }
          }
        }
        max_delta_xr[node_id] = std::sqrt(min_distance_2);
      });

    // Ghost nodes only see part of their cells: take the owner's value
    synchronize(max_delta_xr);

    std::uniform_real_distribution<> distribution(-0.45, 0.45);

    // Local nodes sorted by node number, so that random draws follow node number order
    NodeValue<const int> node_numbers = connectivity.nodeNumber();
    using IdCorrespondance            = std::pair<int, NodeId>;
    Array<IdCorrespondance> node_numbers_to_node_id{node_numbers.numberOfItems()};
    parallel_for(
      node_numbers.numberOfItems(), PUGS_LAMBDA(const NodeId node_id) {
        node_numbers_to_node_id[node_id] = std::make_pair(node_numbers[node_id], node_id);
      });

    const size_t number_of_local_nodes = node_numbers_to_node_id.size();
    if (number_of_local_nodes > 0) {
      std::sort(&node_numbers_to_node_id[0], &node_numbers_to_node_id[0] + number_of_local_nodes,
                [](const IdCorrespondance& a, const IdCorrespondance& b) { return a.first < b.first; });
    }

    RandomEngine& random_engine = RandomEngine::instance();

    Assert(isSynchronized(random_engine), "seed is not synchronized when entering mesh randomization");

    NodeValue<Rd> shift_r{connectivity};

    int i_node_number = 0;
    for (size_t i = 0; i < number_of_local_nodes; ++i) {
      const auto [node_number, node_id] = node_numbers_to_node_id[i];
      // Skip the draws of node numbers that are not present locally
      while (i_node_number < node_number) {
        for (size_t j = 0; j < Dimension; ++j) {
          distribution(random_engine.engine());
        }
        ++i_node_number;
      }

      const double max_delta = max_delta_xr[node_id];

      Rd shift;
      for (size_t i_component = 0; i_component < Dimension; ++i_component) {
        shift[i_component] = max_delta * distribution(random_engine.engine());
      }

      shift_r[node_id] = shift;

      ++i_node_number;
    }

    // A process without any node contributes -1, which never exceeds a valid node number
    const int local_max_node_number =
      (number_of_local_nodes > 0) ? node_numbers_to_node_id[number_of_local_nodes - 1].first : -1;
    const int max_node_number = parallel::allReduceMax(local_max_node_number);

    // Advances random engine to preserve CPU random number generators synchronization
    for (; i_node_number <= max_node_number; ++i_node_number) {
      for (size_t j = 0; j < Dimension; ++j) {
        distribution(random_engine.engine());
      }
    }

    this->_applyBC(shift_r);

#ifndef NDEBUG
    if (not isSynchronized(shift_r)) {
      throw UnexpectedError("randomized mesh coordinates are not synchronized");
    }
#endif   // NDEBUG

    Assert(isSynchronized(random_engine), "seed is not synchronized after mesh randomization");

    return shift_r;
  }

 public:
  /**
   * Builds a new mesh sharing the given mesh's connectivity, with every node
   * displaced by a random (boundary-constrained) shift.
   */
  std::shared_ptr<const MeshType>
  getRandomizedMesh() const
  {
    NodeValue<const Rd> given_xr = m_given_mesh.xr();

    NodeValue<Rd> xr = this->_getDisplacement();

    parallel_for(
      m_given_mesh.numberOfNodes(), PUGS_LAMBDA(const NodeId node_id) { xr[node_id] += given_xr[node_id]; });

    return std::make_shared<MeshType>(m_given_mesh.shared_connectivity(), xr);
  }

  /**
   * Same as getRandomizedMesh(), but only nodes where the user function
   * (evaluated at the initial node coordinates) is true are displaced.
   * Random numbers are drawn for all nodes either way.
   *
   * @param function_symbol_id user function of type R^d -> B selecting the nodes to move
   */
  std::shared_ptr<const MeshType>
  getRandomizedMesh(const FunctionSymbolId& function_symbol_id) const
  {
    NodeValue<const Rd> given_xr = m_given_mesh.xr();

    NodeValue<const bool> is_displaced =
      InterpolateItemValue<bool(const Rd)>::interpolate(function_symbol_id, given_xr);

    NodeValue<Rd> xr = this->_getDisplacement();

    parallel_for(
      m_given_mesh.numberOfNodes(),
      PUGS_LAMBDA(const NodeId node_id) { xr[node_id] = is_displaced[node_id] * xr[node_id] + given_xr[node_id]; });

    return std::make_shared<MeshType>(m_given_mesh.shared_connectivity(), xr);
  }

  MeshRandomizer(const MeshRandomizer&) = delete;
  MeshRandomizer(MeshRandomizer&&)      = delete;

  /**
   * @param given_mesh mesh to randomize. It is referenced, not copied.
   * @param bc_descriptor_list boundary conditions constraining the displacement
   * @throws NormalError if a descriptor type is not supported
   */
  MeshRandomizer(const MeshType& given_mesh,
                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
    : m_given_mesh(given_mesh), m_boundary_condition_list(_getBCList(given_mesh, bc_descriptor_list))
  {}

  ~MeshRandomizer() = default;
};

/**
 * Axis boundary condition: nodes of a boundary line may only move along the
 * line direction.
 */
template <size_t Dimension>
class MeshRandomizer<Dimension>::AxisBoundaryCondition
{
 public:
  using Rd = TinyVector<Dimension, double>;

 private:
  const MeshLineNodeBoundary<Dimension> m_mesh_line_node_boundary;

 public:
  // Direction of the boundary line. MeshRandomizer builds the projector t (x) t
  // from it, which assumes a unit vector (presumably ensured by
  // MeshLineNodeBoundary; TODO confirm).
  const Rd&
  direction() const
  {
    return m_mesh_line_node_boundary.direction();
  }

  // Nodes lying on the boundary line
  const Array<const NodeId>&
  nodeList() const
  {
    return m_mesh_line_node_boundary.nodeList();
  }

  // Takes ownership of the boundary. The named rvalue reference must be
  // std::move'd, otherwise it would be copied.
  explicit AxisBoundaryCondition(MeshLineNodeBoundary<Dimension>&& mesh_line_node_boundary)
    : m_mesh_line_node_boundary(std::move(mesh_line_node_boundary))
  {}

  ~AxisBoundaryCondition() = default;
};

/**
 * Fixed boundary condition: nodes of the boundary are not displaced.
 */
template <size_t Dimension>
class MeshRandomizer<Dimension>::FixedBoundaryCondition
{
 private:
  const MeshNodeBoundary<Dimension> m_mesh_node_boundary;

 public:
  // Nodes that must keep their initial position
  const Array<const NodeId>&
  nodeList() const
  {
    return m_mesh_node_boundary.nodeList();
  }

  // Takes ownership of the boundary. The named rvalue reference must be
  // std::move'd, otherwise it would be copied.
  explicit FixedBoundaryCondition(MeshNodeBoundary<Dimension>&& mesh_node_boundary)
    : m_mesh_node_boundary{std::move(mesh_node_boundary)}
  {}

  ~FixedBoundaryCondition() = default;
};

/**
 * Symmetry boundary condition: nodes of a flat boundary may only move
 * tangentially to it (their normal displacement is removed).
 */
template <size_t Dimension>
class MeshRandomizer<Dimension>::SymmetryBoundaryCondition
{
 public:
  using Rd = TinyVector<Dimension, double>;

 private:
  const MeshFlatNodeBoundary<Dimension> m_mesh_flat_node_boundary;

 public:
  // Outgoing normal of the flat boundary. MeshRandomizer builds
  // I - n (x) n from it, which assumes a unit vector (presumably ensured by
  // MeshFlatNodeBoundary; TODO confirm).
  const Rd&
  outgoingNormal() const
  {
    return m_mesh_flat_node_boundary.outgoingNormal();
  }

  // Nodes lying on the flat boundary
  const Array<const NodeId>&
  nodeList() const
  {
    return m_mesh_flat_node_boundary.nodeList();
  }

  // Takes ownership of the boundary. The named rvalue reference must be
  // std::move'd, otherwise it would be copied.
  explicit SymmetryBoundaryCondition(MeshFlatNodeBoundary<Dimension>&& mesh_flat_node_boundary)
    : m_mesh_flat_node_boundary(std::move(mesh_flat_node_boundary))
  {}

  ~SymmetryBoundaryCondition() = default;
};

#endif   // MESH_RANDOMIZER_HPP