#ifndef MESH_DATA_HPP
#define MESH_DATA_HPP

#include <algebra/TinyVector.hpp>
#include <mesh/ItemValue.hpp>
#include <mesh/SubItemValuePerItem.hpp>
#include <utils/Exceptions.hpp>
#include <utils/Messenger.hpp>
#include <utils/PugsUtils.hpp>

#include <cstddef>
#include <limits>
#include <map>
#include <vector>

// Geometric data attached to a mesh of type M.
//
// For the mesh given at construction, this class stores and refreshes:
//  - Cjr: one Rd vector per (cell j, node r of j),
//  - ljr: l2 norm of Cjr,
//  - njr: Cjr divided by ljr,
//  - xj : cell centers (isobarycenter of the cell vertices),
//  - Vj : cell volumes, computed as (1/Dimension) * sum_r (x_r, C_jr).
//
// The mesh is held by reference: it must outlive the MeshData object.
// Call updateAllData() to recompute every quantity from the current
// node positions returned by the mesh.
template <typename M>
class MeshData
{
 public:
  using MeshType = M;

  static constexpr size_t Dimension = MeshType::Dimension;
  static_assert(Dimension > 0, "dimension must be strictly positive");

  using Rd = TinyVector<Dimension>;

  static constexpr double inv_Dimension = 1. / Dimension;

 private:
  const MeshType& m_mesh;                 // non-owning reference to the described mesh
  NodeValuePerCell<const Rd> m_Cjr;       // Cjr vectors, indexed by (cell, local node)
  NodeValuePerCell<const double> m_ljr;   // l2 norm of each Cjr
  NodeValuePerCell<const Rd> m_njr;       // Cjr / ljr
  CellValue<const Rd> m_xj;               // cell centers
  CellValue<const double> m_Vj;           // cell volumes

  // Recomputes the cell centers as the isobarycenter of the cell vertices.
  //
  // In 1d the center is the midpoint of the two cell nodes. In higher
  // dimension the node positions of the cell are summed and scaled by the
  // connectivity's invCellNbNodes() value of the cell.
  PUGS_INLINE
  void
  _updateCenter()
  {
    const NodeValue<const Rd>& xr   = m_mesh.xr();
    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

    CellValue<Rd> xj(m_mesh.connectivity());
    if constexpr (Dimension == 1) {
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
          const auto& cell_nodes = cell_to_node_matrix[j];
          xj[j]                  = 0.5 * (xr[cell_nodes[0]] + xr[cell_nodes[1]]);
        });
    } else {
      const CellValue<const double>& inv_cell_nb_nodes = m_mesh.connectivity().invCellNbNodes();

      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
          Rd X                   = zero;
          const auto& cell_nodes = cell_to_node_matrix[j];
          for (size_t R = 0; R < cell_nodes.size(); ++R) {
            X += xr[cell_nodes[R]];
          }
          xj[j] = inv_cell_nb_nodes[j] * X;
        });
    }
    m_xj = xj;
  }

  // Recomputes the cell volumes from the node positions and the current Cjr:
  //   Vj = (1/Dimension) * sum over the nodes r of cell j of (x_r, C_jr).
  //
  // m_Cjr must be up to date before this is called.
  PUGS_INLINE
  void
  _updateVolume()
  {
    const NodeValue<const Rd>& xr   = m_mesh.xr();
    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

    CellValue<double> Vj(m_mesh.connectivity());
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
        double sum_cjr_xr      = 0;
        const auto& cell_nodes = cell_to_node_matrix[j];

        for (size_t R = 0; R < cell_nodes.size(); ++R) {
          // NOTE(review): the parenthesised comma expression relies on an
          // overloaded TinyVector operator, (presumably the scalar product)
          // -- confirm in algebra/TinyVector.hpp before touching this line.
          sum_cjr_xr += (xr[cell_nodes[R]], m_Cjr(j, R));
        }
        Vj[j] = inv_Dimension * sum_cjr_xr;
      });
    m_Vj = Vj;
  }

  // Recomputes ljr (l2 norm of Cjr) and njr (Cjr / ljr) from the current m_Cjr.
  //
  // NOTE(review): no guard against a zero-length Cjr; njr is then the result
  // of a division by zero.
  PUGS_INLINE
  void
  _updateLjrAndNjr()
  {
    {
      NodeValuePerCell<double> ljr(m_mesh.connectivity());
      parallel_for(
        m_Cjr.numberOfValues(), PUGS_LAMBDA(size_t jr) { ljr[jr] = l2Norm(m_Cjr[jr]); });
      m_ljr = ljr;
    }

    {
      NodeValuePerCell<Rd> njr(m_mesh.connectivity());
      parallel_for(
        m_Cjr.numberOfValues(), PUGS_LAMBDA(size_t jr) { njr[jr] = (1. / m_ljr[jr]) * m_Cjr[jr]; });
      m_njr = njr;
    }
  }

  // Recomputes Cjr, then ljr and njr, from the current node positions.
  //
  //  - 1d: nothing to do, the values set by the constructor never change.
  //  - 2d: Cjr is half of (x_{r+1} - x_{r-1}), r being taken cyclically in
  //        the cell, with components (dx, dy) mapped to (dy, -dx).
  //  - 3d: per-face node vectors Nlr are computed first, then accumulated
  //        into Cjr for each cell, with a minus sign for faces flagged as
  //        reversed with respect to the cell.
  PUGS_INLINE
  void
  _updateCjr()
  {
    static_assert((Dimension <= 3), "only 1d, 2d and 3d are implemented");

    if constexpr (Dimension == 1) {
      // Cjr/njr/ljr are constant over time
    } else if constexpr (Dimension == 2) {
      const NodeValue<const Rd>& xr   = m_mesh.xr();
      const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

      {
        NodeValuePerCell<Rd> Cjr(m_mesh.connectivity());
        parallel_for(
          m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
            const auto& cell_nodes = cell_to_node_matrix[j];
            const size_t nb_nodes  = cell_nodes.size();
            for (size_t R = 0; R < nb_nodes; ++R) {
              // cyclic next/previous local node numbers
              const size_t Rp1      = (R + 1) % nb_nodes;
              const size_t Rm1      = (R + nb_nodes - 1) % nb_nodes;
              const Rd half_xrp_xrm = 0.5 * (xr[cell_nodes[Rp1]] - xr[cell_nodes[Rm1]]);
              Cjr(j, R)             = Rd{half_xrp_xrm[1], -half_xrp_xrm[0]};
            }
          });
        m_Cjr = Cjr;
      }

      this->_updateLjrAndNjr();
    } else if constexpr (Dimension == 3) {
      const NodeValue<const Rd>& xr = m_mesh.xr();

      // Step 1: per-face node vectors Nlr
      NodeValuePerFace<Rd> Nlr(m_mesh.connectivity());
      const auto& face_to_node_matrix = m_mesh.connectivity().faceToNodeMatrix();

      parallel_for(
        m_mesh.numberOfFaces(), PUGS_LAMBDA(FaceId l) {
          const auto& face_nodes = face_to_node_matrix[l];
          const size_t nb_nodes  = face_nodes.size();

          // dxr[r] = x_{r+1} - x_{r-1}, r being taken cyclically in the face
          std::vector<Rd> dxr(nb_nodes);
          for (size_t r = 0; r < nb_nodes; ++r) {
            dxr[r] = xr[face_nodes[(r + 1) % nb_nodes]] - xr[face_nodes[(r + nb_nodes - 1) % nb_nodes]];
          }

          const double inv_12_nb_nodes = 1. / (12. * nb_nodes);
          for (size_t r = 0; r < nb_nodes; ++r) {
            Rd Nr            = zero;
            const Rd two_dxr = 2 * dxr[r];
            for (size_t s = 0; s < nb_nodes; ++s) {
              Nr += crossProduct((two_dxr - dxr[s]), xr[face_nodes[s]]);
            }
            Nr *= inv_12_nb_nodes;
            Nr -= (1. / 6.) * crossProduct(dxr[r], xr[face_nodes[r]]);
            Nlr(l, r) = Nr;
          }
        });

      // Step 2: accumulate the face contributions into Cjr, cell by cell
      const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

      const auto& cell_to_face_matrix = m_mesh.connectivity().cellToFaceMatrix();

      const auto& cell_face_is_reversed = m_mesh.connectivity().cellFaceIsReversed();

      {
        NodeValuePerCell<Rd> Cjr(m_mesh.connectivity());
        parallel_for(
          Cjr.numberOfValues(), PUGS_LAMBDA(size_t jr) { Cjr[jr] = zero; });

        parallel_for(
          m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
            const auto& cell_nodes = cell_to_node_matrix[j];

            const auto& cell_faces       = cell_to_face_matrix[j];
            const auto& face_is_reversed = cell_face_is_reversed.itemValues(j);

            // Maps a global node id to its local number inside cell j.
            // Returns the largest size_t value when the node is not found.
            // NOTE(review): that sentinel is used unchecked as an index
            // below, so the face/cell connectivity is assumed consistent.
            auto local_node_number_in_cell = [&](NodeId node_number) -> size_t {
              for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
                if (node_number == cell_nodes[i_node]) {
                  return i_node;
                }
              }
              return std::numeric_limits<size_t>::max();
            };

            for (size_t L = 0; L < cell_faces.size(); ++L) {
              const FaceId& l        = cell_faces[L];
              const auto& face_nodes = face_to_node_matrix[l];

              if (face_is_reversed[L]) {
                for (size_t rl = 0; rl < face_nodes.size(); ++rl) {
                  const size_t R = local_node_number_in_cell(face_nodes[rl]);
                  Cjr(j, R) -= Nlr(l, rl);
                }
              } else {
                for (size_t rl = 0; rl < face_nodes.size(); ++rl) {
                  const size_t R = local_node_number_in_cell(face_nodes[rl]);
                  Cjr(j, R) += Nlr(l, rl);
                }
              }
            }
          });

        m_Cjr = Cjr;
      }

      this->_updateLjrAndNjr();
    }
  }

  // Checks that every local cell has a strictly positive volume.
  //
  // The local result is combined through parallel::allReduceAnd (presumably
  // a logical AND over all processes -- see utils/Messenger.hpp).
  // Throws NormalError when the combined result is false.
  void
  _checkCellVolume() const
  {
    bool is_valid = [&] {
      for (CellId j = 0; j < m_mesh.numberOfCells(); ++j) {
        if (m_Vj[j] <= 0) {
          return false;
        }
      }
      return true;
    }();

    if (not parallel::allReduceAnd(is_valid)) {
      throw NormalError("mesh contains cells of non-positive volume");
    }
  }

 public:
  // Returns the mesh these data refer to.
  PUGS_INLINE
  const MeshType&
  mesh() const
  {
    return m_mesh;
  }

  // Returns the Cjr vectors, indexed by (cell, local node).
  PUGS_INLINE
  const NodeValuePerCell<const Rd>&
  Cjr() const
  {
    return m_Cjr;
  }

  // Returns the l2 norms of the Cjr vectors.
  PUGS_INLINE
  const NodeValuePerCell<const double>&
  ljr() const
  {
    return m_ljr;
  }

  // Returns the normalized vectors Cjr / ljr.
  PUGS_INLINE
  const NodeValuePerCell<const Rd>&
  njr() const
  {
    return m_njr;
  }

  // Returns the cell centers.
  PUGS_INLINE
  const CellValue<const Rd>&
  xj() const
  {
    return m_xj;
  }

  // Returns the cell volumes.
  PUGS_INLINE
  const CellValue<const double>&
  Vj() const
  {
    return m_Vj;
  }

  // Recomputes all stored quantities from the current node positions.
  //
  // Order matters: the volumes are computed from Cjr, so Cjr is updated
  // first. Throws NormalError if a cell of non-positive volume is found.
  void
  updateAllData()
  {
    this->_updateCjr();
    this->_updateCenter();
    this->_updateVolume();
    this->_checkCellVolume();
  }

  // Builds the data of the given mesh and computes every quantity once.
  //
  // The mesh is stored by reference and must outlive this object.
  // Throws NormalError if a cell of non-positive volume is found.
  MeshData(const MeshType& mesh) : m_mesh(mesh)
  {
    if constexpr (Dimension == 1) {
      // in 1d Cjr are computed once for all
      {
        NodeValuePerCell<Rd> Cjr(m_mesh.connectivity());
        parallel_for(
          m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
            Cjr(j, 0) = -1;
            Cjr(j, 1) = 1;
          });
        m_Cjr = Cjr;
      }
      // in 1d njr=Cjr (here performs a shallow copy)
      m_njr = m_Cjr;
      {
        NodeValuePerCell<double> ljr(m_mesh.connectivity());
        parallel_for(
          ljr.numberOfValues(), PUGS_LAMBDA(size_t jr) { ljr[jr] = 1; });
        m_ljr = ljr;
      }
    }
    this->updateAllData();
  }

  // Neither copyable nor movable: the object is bound to one mesh reference.
  MeshData(const MeshData&) = delete;
  MeshData(MeshData&&)      = delete;

  MeshData& operator=(const MeshData&) = delete;
  MeshData& operator=(MeshData&&) = delete;

  ~MeshData() = default;
};

#endif   // MESH_DATA_HPP
