#include <mesh/ImplicitMeshSmoother.hpp>

#include <algebra/TinyMatrix.hpp>
#include <algebra/TinyVector.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshCellZone.hpp>
#include <mesh/MeshFlatNodeBoundary.hpp>
#include <mesh/MeshLineNodeBoundary.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <scheme/AxisBoundaryConditionDescriptor.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/DiscreteFunctionVariant.hpp>
#include <scheme/FixedBoundaryConditionDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>
#include <utils/RandomEngine.hpp>

#include <utility>
#include <variant>

// Implicit smoother for a mesh of dimension `Dimension`.
//
// Nodes are split into "fixed" nodes (kept at their given position) and "free"
// nodes. For each space direction a linear system is solved whose rows state
// that a free node sits at the average of the nodes it shares a face with
// (a neighbor is counted once per shared face). Nodes lying on a symmetry
// boundary are finally projected back onto their boundary plane.
template <size_t Dimension>
class ImplicitMeshSmootherHandler::ImplicitMeshSmoother
{
 private:
  using Rd               = TinyVector<Dimension>;
  using Rdxd             = TinyMatrix<Dimension>;
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;

  // Mesh to smooth. Only a reference is stored: the mesh must outlive the smoother.
  const MeshType& m_given_mesh;

  //  class AxisBoundaryCondition;
  class FixedBoundaryCondition;
  class SymmetryBoundaryCondition;

  //  using BoundaryCondition = std::variant<AxisBoundaryCondition, FixedBoundaryCondition, SymmetryBoundaryCondition>;
  using BoundaryCondition = std::variant<FixedBoundaryCondition, SymmetryBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;
  BoundaryConditionList m_boundary_condition_list;

  // Converts the boundary condition descriptors into the boundary conditions
  // handled by the smoother.
  //
  // Only `symmetry` and `fixed` descriptors are accepted; any other type
  // throws a NormalError naming the offending descriptor.
  BoundaryConditionList
  _getBCList(const MeshType& mesh,
             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
  {
    BoundaryConditionList bc_list;

    for (const auto& bc_descriptor : bc_descriptor_list) {
      switch (bc_descriptor->type()) {
        // case IBoundaryConditionDescriptor::Type::axis: {
        //   if constexpr (Dimension == 1) {
        //     bc_list.emplace_back(FixedBoundaryCondition{getMeshNodeBoundary(mesh,
        //     bc_descriptor->boundaryDescriptor())});
        //   } else {
        //     bc_list.emplace_back(
        //       AxisBoundaryCondition{getMeshLineNodeBoundary(mesh, bc_descriptor->boundaryDescriptor())});
        //   }
        //   break;
        // }
      case IBoundaryConditionDescriptor::Type::symmetry: {
        bc_list.emplace_back(
          SymmetryBoundaryCondition{getMeshFlatNodeBoundary(mesh, bc_descriptor->boundaryDescriptor())});
        break;
      }
      case IBoundaryConditionDescriptor::Type::fixed: {
        bc_list.emplace_back(FixedBoundaryCondition{getMeshNodeBoundary(mesh, bc_descriptor->boundaryDescriptor())});
        break;
      }
      default: {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for mesh smoother";
        throw NormalError(error_msg.str());
      }
      }
    }

    return bc_list;
  }

  // Applies the boundary conditions to the node flags.
  //
  // - Fixed boundary: every node of the boundary is flagged in `is_fixed`.
  // - Symmetry boundary: the k-th symmetry boundary of the list (k starts at 1)
  //   tags its nodes with k in `is_symmetric`, unless the node is already
  //   fixed. A node already tagged by a previous symmetry boundary (i.e. lying
  //   on two symmetry boundaries) is flagged as fixed instead.
  //
  // The result depends on the order of the boundary conditions in the list.
  void
  _browseBC(NodeValue<bool>& is_fixed, NodeValue<int>& is_symmetric) const
  {
    int n_sym = 1;
    for (auto&& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using BCType = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<BCType, SymmetryBoundaryCondition>) {
            //   const Rd& n = bc.outgoingNormal();

            //   const Rdxd I   = identity;
            //   const Rdxd nxn = tensorProduct(n, n);
            //   const Rdxd P   = I - nxn;

            const Array<const NodeId>& node_list = bc.nodeList();
            parallel_for(
              node_list.size(), PUGS_LAMBDA(const size_t i_node) {
                const NodeId node_id = node_list[i_node];
                if (is_symmetric[node_id] != 0) {
                  is_fixed[node_id] = true;
                } else if (!is_fixed[node_id]) {
                  is_symmetric[node_id] = n_sym;
                }
              });
            n_sym += 1;
            // } else if constexpr (std::is_same_v<BCType, AxisBoundaryCondition>) {
            //   if constexpr (Dimension > 1) {
            //     const Rd& t = bc.direction();

            //     const Rdxd txt = tensorProduct(t, t);

            //     const Array<const NodeId>& node_list = bc.nodeList();
            //     parallel_for(
            //       node_list.size(), PUGS_LAMBDA(const size_t i_node) {
            //         const NodeId node_id = node_list[i_node];

            //         shift[node_id] = txt * shift[node_id];
            //       });
            //   } else {
            //     throw UnexpectedError("AxisBoundaryCondition make no sense in dimension 1");
            //   }

          } else if constexpr (std::is_same_v<BCType, FixedBoundaryCondition>) {
            const Array<const NodeId>& node_list = bc.nodeList();
            parallel_for(
              node_list.size(), PUGS_LAMBDA(const size_t i_node) {
                const NodeId node_id = node_list[i_node];
                //                shift[node_id]       = zero;
                is_fixed[node_id] = true;
              });

          } else {
            throw UnexpectedError("invalid boundary condition type");
          }
        },
        boundary_condition);
    }
  }

  // Numbers fixed and free nodes separately, following increasing node id.
  //
  // On return `nb_fixed`/`nb_free` hold the number of nodes of each kind,
  // `id_fixed[n]` is the rank of n among fixed nodes if n is fixed and
  // `id_free[n]` is the rank of n among free nodes otherwise. The entry of the
  // other array is left untouched for that node.
  void
  _computeMatrixSize(const NodeValue<bool>& is_fixed,
                     NodeValue<int>& id_fixed,
                     NodeValue<int>& id_free,
                     size_t& nb_free,
                     size_t& nb_fixed) const
  {
    nb_free  = 0;
    nb_fixed = 0;
    for (NodeId n_id = 0; n_id < m_given_mesh.numberOfNodes(); n_id++) {
      if (is_fixed[n_id]) {
        id_fixed[n_id] = nb_fixed;
        nb_fixed += 1;
      } else {
        id_free[n_id] = nb_free;
        nb_free += 1;
      }
    }
  }

  // Builds the local-to-global node tables and the per-row entry counts used
  // to size the two matrices.
  //
  // `gid_node_free[i]` / `gid_node_fixed[i]` receive the node id of the i-th
  // free / fixed node (same ordering as `_computeMatrixSize`). For the i-th
  // free node, `non_zeros_free[i]` counts 1 (diagonal) plus one per free
  // neighbor visit and `non_zeros_fixed[i]` one per fixed neighbor visit.
  // Neighbors are visited once per shared face, so these counts are upper
  // bounds when two nodes share several faces.
  void
  _findLocalIds(const NodeValue<bool>& is_fixed,
                Array<NodeId>& gid_node_free,
                Array<NodeId>& gid_node_fixed,
                Array<int>& non_zeros_free,
                Array<int>& non_zeros_fixed) const
  {
    const auto& node_to_face_matrix = m_given_mesh.connectivity().nodeToFaceMatrix();
    const auto& face_to_node_matrix = m_given_mesh.connectivity().faceToNodeMatrix();

    int local_free_id  = 0;
    int local_fixed_id = 0;
    non_zeros_free.fill(1);
    // Warning this is not correct
    // NOTE(review): the original warning does not say what is incorrect; it
    // presumably refers to the disabled counting below, which indexes
    // `non_zeros_fixed` by fixed-node rank while the array is used per free row.
    non_zeros_fixed.fill(0);
    for (NodeId n_id = 0; n_id < m_given_mesh.numberOfNodes(); n_id++) {
      if (is_fixed[n_id]) {
        gid_node_fixed[local_fixed_id] = n_id;
        // for (size_t i_face = 0; i_face < node_to_face_matrix[n_id].size(); ++i_face) {
        //   FaceId face_id = node_to_face_matrix[n_id][i_face];

        //   for (size_t i_node = 0; i_node < face_to_node_matrix[face_id].size(); ++i_node) {
        //     NodeId node_id = face_to_node_matrix[face_id][i_node];
        //     if ((node_id == n_id) or (is_fixed[node_id])) {
        //       continue;
        //     } else {
        //       non_zeros_fixed[local_fixed_id] += 1;
        //     }
        //   }
        // }
        local_fixed_id++;
      } else {
        gid_node_free[local_free_id] = n_id;
        for (size_t i_face = 0; i_face < node_to_face_matrix[n_id].size(); ++i_face) {
          FaceId face_id = node_to_face_matrix[n_id][i_face];
          for (size_t i_node = 0; i_node < face_to_node_matrix[face_id].size(); ++i_node) {
            NodeId node_id = face_to_node_matrix[face_id][i_node];
            if ((node_id == n_id)) {
              continue;
            } else if ((is_fixed[node_id])) {
              non_zeros_fixed[local_free_id] += 1;
            } else {
              non_zeros_free[local_free_id] += 1;
            }
          }
        }
        local_free_id++;
      }
    }
  }

  // Assembles the free-free matrix `Afree` and the free-fixed matrix `Afixed`.
  //
  // For each free node (one row) and each neighbor visit (once per shared
  // face): the diagonal of `Afree` is incremented and -1 is accumulated in the
  // neighbor's column, in `Afixed` if the neighbor is fixed, in `Afree`
  // otherwise. For a node tagged by a symmetry boundary, free neighbors that
  // do not carry the same symmetry tag are ignored.
  void
  _fillMatrix(CRSMatrixDescriptor<double>& Afree,
              CRSMatrixDescriptor<double>& Afixed,
              const NodeValue<bool>& is_fixed,
              const NodeValue<int>& is_symmetric,
              const NodeValue<int>& id_free,
              const NodeValue<int>& id_fixed) const
  // ,
  //           const Array<NodeId>& gid_node_free,
  //           const Array<NodeId>& gid_node_fixed) const
  {
    const auto& node_to_face_matrix = m_given_mesh.connectivity().nodeToFaceMatrix();
    const auto& face_to_node_matrix = m_given_mesh.connectivity().faceToNodeMatrix();
    int local_free_id               = 0;
    for (NodeId n_id = 0; n_id < m_given_mesh.numberOfNodes(); n_id++) {
      if (is_fixed[n_id]) {
        continue;
      } else {
        // rows are visited in the same order as the free-node numbering
        Assert(id_free[n_id] == local_free_id, "bad matrix definition");
        if (is_symmetric[n_id]) {
          for (size_t i_face = 0; i_face < node_to_face_matrix[n_id].size(); ++i_face) {
            FaceId face_id = node_to_face_matrix[n_id][i_face];
            for (size_t i_node = 0; i_node < face_to_node_matrix[face_id].size(); ++i_node) {
              NodeId node_id = face_to_node_matrix[face_id][i_node];
              if (node_id == n_id) {
                continue;
              } else if ((is_fixed[node_id])) {
                Afree(local_free_id, local_free_id) += 1;
                Afixed(local_free_id, id_fixed[node_id]) -= 1;
              } else if (is_symmetric[node_id] == is_symmetric[n_id]) {
                Afree(local_free_id, local_free_id) += 1;
                Afree(local_free_id, id_free[node_id]) -= 1;
              }
            }
          }

        } else {
          for (size_t i_face = 0; i_face < node_to_face_matrix[n_id].size(); ++i_face) {
            FaceId face_id = node_to_face_matrix[n_id][i_face];
            for (size_t i_node = 0; i_node < face_to_node_matrix[face_id].size(); ++i_node) {
              NodeId node_id = face_to_node_matrix[face_id][i_node];
              if (node_id == n_id) {
                continue;
              } else if ((is_fixed[node_id])) {
                Afree(local_free_id, local_free_id) += 1;
                Afixed(local_free_id, id_fixed[node_id]) -= 1;
              } else {
                Afree(local_free_id, local_free_id) += 1;
                Afree(local_free_id, id_free[node_id]) -= 1;
              }
            }
          }
        }
        local_free_id++;
      }
    }
  }

  // Removes, for every node of every symmetry boundary, the component of its
  // displacement (new_xr - old_xr) along the boundary normal, so that the node
  // only moves within the boundary plane.
  void
  _correct_sym(const NodeValue<const Rd>& old_xr, NodeValue<Rd>& new_xr) const
  {
    for (auto&& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using BCType = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<BCType, SymmetryBoundaryCondition>) {
            const Rd& n = bc.outgoingNormal();

            // P = I - n (x) n is the projector used to drop the normal component
            const Rdxd I   = identity;
            const Rdxd nxn = tensorProduct(n, n);
            const Rdxd P   = I - nxn;

            const Array<const NodeId>& node_list = bc.nodeList();
            parallel_for(
              node_list.size(), PUGS_LAMBDA(const size_t i_node) {
                const NodeId node_id = node_list[i_node];
                Rd depl              = new_xr[node_id] - old_xr[node_id];
                new_xr[node_id]      = old_xr[node_id] + P * depl;
              });
            // } else if constexpr (std::is_same_v<BCType, AxisBoundaryCondition>) {
            //   if constexpr (Dimension > 1) {
            //     const Rd& t = bc.direction();

            //     const Rdxd txt = tensorProduct(t, t);

            //     const Array<const NodeId>& node_list = bc.nodeList();
            //     parallel_for(
            //       node_list.size(), PUGS_LAMBDA(const size_t i_node) {
            //         const NodeId node_id = node_list[i_node];

            //         shift[node_id] = txt * shift[node_id];
            //       });
            //   } else {
            //     throw UnexpectedError("AxisBoundaryCondition make no sense in dimension 1");
            //   }
          }
        },
        boundary_condition);
    }
  }

  // Computes the smoothed node positions.
  //
  // `is_displaced` tells which nodes are allowed to move: any node for which
  // it is false is fixed, as are the nodes fixed by the boundary conditions.
  // Fixed nodes keep their given position. Returns the new position of every
  // node of the mesh.
  NodeValue<Rd>
  _getPosition(NodeValue<const bool> is_displaced) const
  {
    const ConnectivityType& connectivity = m_given_mesh.connectivity();
    NodeValue<const Rd> given_xr         = m_given_mesh.xr();

    NodeValue<Rd> pos_r{connectivity};
    NodeValue<bool> is_fixed{connectivity};
    NodeValue<int> is_symmetric{connectivity};
    is_fixed.fill(false);
    is_symmetric.fill(0);
    parallel_for(
      m_given_mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
        pos_r[node_id] = given_xr[node_id];
        if (not is_displaced[node_id]) {
          is_fixed[node_id] = true;
        }
      });

    _browseBC(is_fixed, is_symmetric);
    // std::cout << "is_fixed " << is_fixed << "\n";
    // std::cout << "is_symmetric " << is_symmetric << "\n";
    size_t nb_free  = 0;
    size_t nb_fixed = 0;
    NodeValue<int> id_free{connectivity};
    NodeValue<int> id_fixed{connectivity};
    _computeMatrixSize(is_fixed, id_fixed, id_free, nb_free, nb_fixed);
    std::cout << " nb_free " << nb_free << " nb_fixed " << nb_fixed << "\n";
    // both matrices have one row per free node, hence the common size
    Array<int> non_zeros_free{nb_free};
    Array<int> non_zeros_fixed{nb_free};
    Array<NodeId> gid_node_free{nb_free};
    Array<NodeId> gid_node_fixed{nb_fixed};
    _findLocalIds(is_fixed, gid_node_free, gid_node_fixed, non_zeros_free, non_zeros_fixed);

    CRSMatrixDescriptor<double> Afree(nb_free, nb_free, non_zeros_free);
    CRSMatrixDescriptor<double> Afixed(nb_free, nb_fixed, non_zeros_fixed);
    LinearSolver solver;
    _fillMatrix(Afree, Afixed, is_fixed, is_symmetric, id_free, id_fixed);
    Vector<double> F{nb_fixed};
    CRSMatrix Mfree{Afree.getCRSMatrix()};
    CRSMatrix Mfixed{Afixed.getCRSMatrix()};
    // the same matrices are used for each space direction, only the data change
    for (size_t dir = 0; dir < Dimension; ++dir) {
      // Afixed holds -1 entries: using the opposite of the fixed coordinates
      // makes b the sum of the fixed neighbors' coordinates
      for (size_t lid_node = 0; lid_node < nb_fixed; lid_node++) {
        F[lid_node] = -given_xr[gid_node_fixed[lid_node]][dir];
      }
      Vector<double> X{nb_free};
      // Start from the current coordinates of the free nodes: X is otherwise
      // never written before the solve, so a solver reading it as its initial
      // guess would start from indeterminate values.
      // NOTE(review): whether LinearSolver reads X on entry is not visible here.
      for (size_t lid_node = 0; lid_node < nb_free; lid_node++) {
        X[lid_node] = given_xr[gid_node_free[lid_node]][dir];
      }
      Vector<double> b{nb_free};
      b = Mfixed * F;
      // std::cout << " Mfixed " << Mfixed << "\n";
      // std::cout << " F " << F << "\n";
      // std::cout << " b " << b << "\n";
      // std::cout << " Mfree " << Mfree << "\n";
      solver.solveLocalSystem(Mfree, X, b);
      parallel_for(
        m_given_mesh.numberOfNodes(), PUGS_LAMBDA(NodeId node_id) {
          if (!is_fixed[node_id]) {
            pos_r[node_id][dir] = X[id_free[node_id]];
          }
        });
    }
    _correct_sym(m_given_mesh.xr(), pos_r);
    //    synchronize(pos_r);
    return pos_r;
  }

 public:
  // Smooths the mesh allowing every node to move (boundary conditions still
  // apply). Returns a new mesh sharing the connectivity of the given one.
  std::shared_ptr<const IMesh>
  getSmoothedMesh() const
  {
    NodeValue<bool> is_displaced{m_given_mesh.connectivity()};
    is_displaced.fill(true);

    NodeValue<Rd> xr = this->_getPosition(is_displaced);

    return std::make_shared<MeshType>(m_given_mesh.shared_connectivity(), xr);
  }

  // Smooths the mesh moving only the nodes for which the boolean user function
  // `function_symbol_id`, interpolated at the node positions, is true.
  std::shared_ptr<const IMesh>
  getSmoothedMesh(const FunctionSymbolId& function_symbol_id) const
  {
    NodeValue<const Rd> given_xr = m_given_mesh.xr();
    NodeValue<const bool> is_displaced =
      InterpolateItemValue<bool(const Rd)>::interpolate(function_symbol_id, given_xr);

    NodeValue<Rd> xr = this->_getPosition(is_displaced);

    return std::make_shared<MeshType>(m_given_mesh.shared_connectivity(), xr);
  }

  // Smooths the mesh inside zones described by P0 discrete functions.
  //
  // A node may move if, for at least one function of the list, the function is
  // non-zero in all the cells connected to the node.
  std::shared_ptr<const IMesh>
  getSmoothedMesh(
    const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_variant_list) const
  {
    auto node_to_cell_matrix = m_given_mesh.connectivity().nodeToCellMatrix();

    NodeValue<bool> is_displaced{m_given_mesh.connectivity()};
    is_displaced.fill(false);
    for (size_t i_zone = 0; i_zone < discrete_function_variant_list.size(); ++i_zone) {
      auto is_zone_cell = discrete_function_variant_list[i_zone]->get<DiscreteFunctionP0<Dimension, const double>>();

      parallel_for(
        m_given_mesh.numberOfNodes(), PUGS_LAMBDA(const NodeId node_id) {
          auto node_cell_list = node_to_cell_matrix[node_id];
          bool displace       = true;
          for (size_t i_node_cell = 0; i_node_cell < node_cell_list.size(); ++i_node_cell) {
            const CellId cell_id = node_cell_list[i_node_cell];
            displace &= (is_zone_cell[cell_id] != 0);
          }
          // zones only add displaced nodes, they never reset the flag
          if (displace) {
            is_displaced[node_id] = true;
          }
        });
    }

    NodeValue<Rd> xr = this->_getPosition(is_displaced);

    return std::make_shared<MeshType>(m_given_mesh.shared_connectivity(), xr);
  }

  ImplicitMeshSmoother(const ImplicitMeshSmoother&) = delete;
  ImplicitMeshSmoother(ImplicitMeshSmoother&&)      = delete;

  // Builds a smoother for `given_mesh` (stored by reference) using the
  // boundary conditions described by `bc_descriptor_list`.
  // Throws a NormalError if a descriptor is neither `symmetry` nor `fixed`.
  ImplicitMeshSmoother(const MeshType& given_mesh,
                       const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
    : m_given_mesh(given_mesh), m_boundary_condition_list(this->_getBCList(given_mesh, bc_descriptor_list))
  {}

  ~ImplicitMeshSmoother() = default;
};

// Boundary condition holding the node boundary whose nodes the smoother keeps
// in place.
template <size_t Dimension>
class ImplicitMeshSmootherHandler::ImplicitMeshSmoother<Dimension>::FixedBoundaryCondition
{
 private:
  const MeshNodeBoundary<Dimension> m_mesh_node_boundary;

 public:
  // Nodes of the underlying boundary.
  const Array<const NodeId>&
  nodeList() const
  {
    return m_mesh_node_boundary.nodeList();
  }

  // Takes ownership of `mesh_node_boundary`. The named rvalue-reference
  // parameter is an lvalue, so it must be moved explicitly to avoid a copy.
  FixedBoundaryCondition(MeshNodeBoundary<Dimension>&& mesh_node_boundary)
    : m_mesh_node_boundary{std::move(mesh_node_boundary)}
  {}

  ~FixedBoundaryCondition() = default;
};

// Boundary condition holding a flat node boundary: gives access to its nodes
// and to its outgoing normal.
template <size_t Dimension>
class ImplicitMeshSmootherHandler::ImplicitMeshSmoother<Dimension>::SymmetryBoundaryCondition
{
 public:
  using Rd = TinyVector<Dimension, double>;

 private:
  const MeshFlatNodeBoundary<Dimension> m_mesh_flat_node_boundary;

 public:
  // Outgoing normal of the underlying flat boundary.
  const Rd&
  outgoingNormal() const
  {
    return m_mesh_flat_node_boundary.outgoingNormal();
  }

  // Nodes of the underlying flat boundary.
  const Array<const NodeId>&
  nodeList() const
  {
    return m_mesh_flat_node_boundary.nodeList();
  }

  // Takes ownership of `mesh_flat_node_boundary`. The named rvalue-reference
  // parameter is an lvalue, so it must be moved explicitly to avoid a copy.
  SymmetryBoundaryCondition(MeshFlatNodeBoundary<Dimension>&& mesh_flat_node_boundary)
    : m_mesh_flat_node_boundary(std::move(mesh_flat_node_boundary))
  {}

  ~SymmetryBoundaryCondition() = default;
};

// Smooths `mesh` under the given boundary conditions, every node being
// allowed to move. Dispatches on the mesh dimension: 2D and 3D are supported,
// 1D throws NotImplementedError and any other dimension UnexpectedError.
std::shared_ptr<const IMesh>
ImplicitMeshSmootherHandler::getSmoothedMesh(
  const std::shared_ptr<const IMesh>& mesh,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list) const
{
  const auto mesh_dimension = mesh->dimension();

  if (mesh_dimension == 1) {
    throw NotImplementedError("ImplicitMeshSmoother not implemented in 1D");
  }

  if (mesh_dimension == 2) {
    using Mesh2D = Mesh<Connectivity<2>>;
    ImplicitMeshSmoother<2> mesh_smoother(dynamic_cast<const Mesh2D&>(*mesh), bc_descriptor_list);
    return mesh_smoother.getSmoothedMesh();
  }

  if (mesh_dimension == 3) {
    using Mesh3D = Mesh<Connectivity<3>>;
    ImplicitMeshSmoother<3> mesh_smoother(dynamic_cast<const Mesh3D&>(*mesh), bc_descriptor_list);
    return mesh_smoother.getSmoothedMesh();
  }

  throw UnexpectedError("invalid mesh dimension");
}
// Smooths `mesh` under the given boundary conditions, the set of nodes allowed
// to move being selected by the user function `function_symbol_id`.
// Dispatches on the mesh dimension: 2D and 3D are supported, 1D throws
// NotImplementedError and any other dimension UnexpectedError.
std::shared_ptr<const IMesh>
ImplicitMeshSmootherHandler::getSmoothedMesh(
  const std::shared_ptr<const IMesh>& mesh,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const FunctionSymbolId& function_symbol_id) const
{
  const auto mesh_dimension = mesh->dimension();

  if (mesh_dimension == 1) {
    throw NotImplementedError("ImplicitMeshSmoother not implemented in 1D");
  }

  if (mesh_dimension == 2) {
    using Mesh2D = Mesh<Connectivity<2>>;
    ImplicitMeshSmoother<2> mesh_smoother(dynamic_cast<const Mesh2D&>(*mesh), bc_descriptor_list);
    return mesh_smoother.getSmoothedMesh(function_symbol_id);
  }

  if (mesh_dimension == 3) {
    using Mesh3D = Mesh<Connectivity<3>>;
    ImplicitMeshSmoother<3> mesh_smoother(dynamic_cast<const Mesh3D&>(*mesh), bc_descriptor_list);
    return mesh_smoother.getSmoothedMesh(function_symbol_id);
  }

  throw UnexpectedError("invalid mesh dimension");
}

// Smooths `mesh` under the given boundary conditions, the set of nodes allowed
// to move being selected by a list of discrete functions.
//
// The discrete functions must all live on the same mesh, and that mesh must be
// `mesh` itself; otherwise a NormalError is thrown. Then dispatches on the
// mesh dimension: 2D and 3D are supported, 1D throws NotImplementedError and
// any other dimension UnexpectedError.
std::shared_ptr<const IMesh>
ImplicitMeshSmootherHandler::getSmoothedMesh(
  const std::shared_ptr<const IMesh>& mesh,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const std::vector<std::shared_ptr<const DiscreteFunctionVariant>>& discrete_function_variant_list) const
{
  if (not hasSameMesh(discrete_function_variant_list)) {
    throw NormalError("discrete functions are not defined on the same mesh");
  }

  std::shared_ptr<const IMesh> function_mesh = getCommonMesh(discrete_function_variant_list);
  if (function_mesh != mesh) {
    throw NormalError("discrete functions are not defined on the smoothed mesh");
  }

  const auto mesh_dimension = mesh->dimension();

  if (mesh_dimension == 1) {
    throw NotImplementedError("ImplicitMeshSmoother not implemented in 1D");
  }

  if (mesh_dimension == 2) {
    using Mesh2D = Mesh<Connectivity<2>>;
    ImplicitMeshSmoother<2> mesh_smoother(dynamic_cast<const Mesh2D&>(*mesh), bc_descriptor_list);
    return mesh_smoother.getSmoothedMesh(discrete_function_variant_list);
  }

  if (mesh_dimension == 3) {
    using Mesh3D = Mesh<Connectivity<3>>;
    ImplicitMeshSmoother<3> mesh_smoother(dynamic_cast<const Mesh3D&>(*mesh), bc_descriptor_list);
    return mesh_smoother.getSmoothedMesh(discrete_function_variant_list);
  }

  throw UnexpectedError("invalid mesh dimension");
}
