#ifndef MESH_DATA_HPP
#define MESH_DATA_HPP

#include <algebra/TinyVector.hpp>
#include <mesh/IMeshData.hpp>
#include <mesh/ItemValue.hpp>
#include <mesh/SubItemValuePerItem.hpp>
#include <utils/Exceptions.hpp>
#include <utils/Messenger.hpp>
#include <utils/PugsUtils.hpp>

#include <limits>
#include <map>
#include <vector>

template <size_t Dimension>
class Connectivity;

template <typename ConnectivityType>
class Mesh;

// Geometric data associated to a mesh of dimension 1, 2 or 3.
//
// Every quantity is computed lazily: the public accessors (Nlr(), Cjr(),
// Vj(), xj(), ...) compute the requested data on first call, cache it in the
// corresponding member and return the cached value on subsequent calls. An
// accessor may trigger the computation of the quantities it depends on (for
// instance Vj() requires Cjr(), and Cjr() requires Nlr() in 3d).
//
// Instances must be created through MeshDataManager (private constructor);
// they are neither copyable nor movable.
template <size_t Dimension>
class MeshData : public IMeshData
{
 public:
  static_assert(Dimension > 0, "dimension must be strictly positive");
  static_assert((Dimension <= 3), "only 1d, 2d and 3d are implemented");

  using MeshType = Mesh<Connectivity<Dimension>>;

  using Rd = TinyVector<Dimension>;

  // 1/d, used in the volume formula Vj = (1/d) sum_r (xr, Cjr)
  static constexpr double inv_Dimension = 1. / Dimension;

 private:
  // The mesh is stored by reference: it must outlive this object
  const MeshType& m_mesh;

  // Cached quantities; each one stays "not built" until first requested
  NodeValuePerFace<const Rd> m_Nlr;            // per face node contributions to the face normal
  NodeValuePerFace<const Rd> m_nlr;            // Nlr scaled to unit length
  NodeValuePerCell<const Rd> m_Cjr;            // vectors attached to each (cell, node) pair
  NodeValuePerCell<const double> m_ljr;        // Euclidean norms of Cjr
  NodeValuePerCell<const Rd> m_njr;            // Cjr scaled to unit length
  FaceValue<const Rd> m_xl;                    // face vertices isobarycenters
  CellValue<const Rd> m_cell_centroid;         // cell centroids
  CellValue<const Rd> m_cell_iso_barycenter;   // cell vertices isobarycenters
  CellValue<const double> m_Vj;                // cell volumes
  FaceValue<const double> m_ll;                // sum of |Nlr| over the nodes of each face

  // Computes m_ll: for each face, the sum of the norms of its Nlr vectors.
  // Requesting it in 1d is a compile-time error.
  PUGS_INLINE
  void
  _compute_ll()
  {
    if constexpr (Dimension == 1) {
      static_assert(Dimension != 1, "ll does not make sense in 1d");
    } else {
      const auto& Nlr = this->Nlr();

      FaceValue<double> ll{m_mesh.connectivity()};

      const auto& face_to_node_matrix = m_mesh.connectivity().faceToNodeMatrix();
      parallel_for(
        m_mesh.numberOfFaces(), PUGS_LAMBDA(FaceId face_id) {
          const auto& face_nodes = face_to_node_matrix[face_id];

          double length = 0;
          for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
            length += l2Norm(Nlr(face_id, i_node));
          }

          ll[face_id] = length;
        });

      m_ll = ll;
    }
  }

  // Computes m_nlr = Nlr / |Nlr| for every (face, node) pair.
  // Requesting it in 1d is a compile-time error.
  PUGS_INLINE
  void
  _compute_nlr()
  {
    if constexpr (Dimension == 1) {
      static_assert(Dimension != 1, "nlr does not make sense in 1d");
    } else {
      const auto& Nlr = this->Nlr();

      NodeValuePerFace<Rd> nlr{m_mesh.connectivity()};

      parallel_for(
        Nlr.numberOfValues(), PUGS_LAMBDA(size_t lr) {
          double length = l2Norm(Nlr[lr]);
          nlr[lr]       = 1. / length * Nlr[lr];
        });

      m_nlr = nlr;
    }
  }

  // Computes m_xl: the arithmetic mean of the positions of each face's nodes.
  // Requesting it in 1d is a compile-time error.
  PUGS_INLINE
  void
  _computeFaceIsobarycenter()
  {
    if constexpr (Dimension == 1) {
      static_assert(Dimension != 1, "xl does not make sense in 1d");
    } else {
      const NodeValue<const Rd>& xr = m_mesh.xr();

      const auto& face_to_node_matrix = m_mesh.connectivity().faceToNodeMatrix();
      FaceValue<Rd> xl(m_mesh.connectivity());
      parallel_for(
        m_mesh.numberOfFaces(), PUGS_LAMBDA(FaceId face_id) {
          Rd X                   = zero;
          const auto& face_nodes = face_to_node_matrix[face_id];
          for (size_t R = 0; R < face_nodes.size(); ++R) {
            X += xr[face_nodes[R]];
          }
          xl[face_id] = 1. / face_nodes.size() * X;
        });
      m_xl = xl;
    }
  }

  // Computes m_cell_iso_barycenter: the arithmetic mean of the positions of
  // each cell's nodes.
  PUGS_INLINE
  void
  _computeCellIsoBarycenter()
  {
    if constexpr (Dimension == 1) {
      const NodeValue<const Rd>& xr = m_mesh.xr();

      const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

      CellValue<Rd> cell_iso_barycenter{m_mesh.connectivity()};
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
          // a 1d cell is a segment: only its two nodes are used
          const auto& cell_nodes = cell_to_node_matrix[j];
          cell_iso_barycenter[j] = 0.5 * (xr[cell_nodes[0]] + xr[cell_nodes[1]]);
        });
      m_cell_iso_barycenter = cell_iso_barycenter;
    } else {
      const NodeValue<const Rd>& xr = m_mesh.xr();

      const CellValue<const double>& inv_cell_nb_nodes = m_mesh.connectivity().invCellNbNodes();

      const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();
      CellValue<Rd> cell_iso_barycenter{m_mesh.connectivity()};
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
          Rd X                   = zero;
          const auto& cell_nodes = cell_to_node_matrix[j];
          for (size_t R = 0; R < cell_nodes.size(); ++R) {
            X += xr[cell_nodes[R]];
          }
          cell_iso_barycenter[j] = inv_cell_nb_nodes[j] * X;
        });
      m_cell_iso_barycenter = cell_iso_barycenter;
    }
  }

  // Computes m_cell_centroid.
  //  - 1d: the centroid is the isobarycenter.
  //  - 2d: the cell is split into triangles (xj, xr, xr+1), where xj is the
  //        isobarycenter, and their centroids are averaged, weighted by their
  //        signed areas.
  //  - 3d: the cell is split into tetrahedra (xj, xl, xr, xr+1) built on each
  //        face edge, and their centroids are averaged, weighted by their
  //        signed volumes. Face orientation is corrected with
  //        cellFaceIsReversed.
  // Requires Vj() in 2d/3d, so it may throw NormalError on invalid meshes.
  PUGS_INLINE
  void
  _computeCellCentroid()
  {
    const CellValue<const Rd> cell_iso_barycenter = this->cellIsoBarycenter();
    if constexpr (Dimension == 1) {
      m_cell_centroid = cell_iso_barycenter;
    } else if constexpr (Dimension == 2) {
      const CellValue<const double> Vj = this->Vj();
      const NodeValue<const Rd>& xr    = m_mesh.xr();
      const auto& cell_to_node_matrix  = m_mesh.connectivity().cellToNodeMatrix();

      CellValue<Rd> cell_centroid{m_mesh.connectivity()};
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
          const Rd xj = cell_iso_barycenter[j];

          const auto& cell_nodes = cell_to_node_matrix[j];

          Rd sum = zero;
          for (size_t R = 0; R < cell_nodes.size(); ++R) {
            const Rd& xr0 = xr[cell_nodes[R]];
            const Rd& xr1 = xr[cell_nodes[(R + 1) % cell_nodes.size()]];

            const Rd xjxr0 = xr0 - xj;
            const Rd xjxr1 = xr1 - xj;

            // signed area of the triangle (xj, xr0, xr1)
            const double Vjl = 0.5 * (xjxr0[0] * xjxr1[1] - xjxr0[1] * xjxr1[0]);

            // the triangle centroid is (xr0 + xr1 + xj) / 3; the 1/3 is applied below
            sum += Vjl * (xr0 + xr1 + xj);
          }

          sum *= 1 / (3 * Vj[j]);

          cell_centroid[j] = sum;
        });
      m_cell_centroid = cell_centroid;
    } else {
      const auto& face_center           = this->xl();
      const CellValue<const double> Vj  = this->Vj();
      const NodeValue<const Rd>& xr     = m_mesh.xr();
      const auto& cell_to_face_matrix   = m_mesh.connectivity().cellToFaceMatrix();
      const auto& face_to_node_matrix   = m_mesh.connectivity().faceToNodeMatrix();
      const auto& cell_face_is_reversed = m_mesh.connectivity().cellFaceIsReversed();

      CellValue<Rd> cell_centroid{m_mesh.connectivity()};
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
          // use the local copy rather than the member: avoids capturing `this`
          const Rd xj = cell_iso_barycenter[j];

          const auto& cell_faces = cell_to_face_matrix[j];

          Rd sum = zero;
          for (size_t i_face = 0; i_face < cell_faces.size(); ++i_face) {
            const FaceId face_id = cell_faces[i_face];

            const Rd xl = face_center[face_id];

            const Rd xjxl = xl - xj;

            const auto& face_nodes = face_to_node_matrix[face_id];

            const Rd xl_plus_xj = xl + xj;

            // flips the tetrahedra volumes when the face is reversed with respect to the cell
            double sign = (cell_face_is_reversed(j, i_face)) ? -1 : 1;

            for (size_t i_face_node = 0; i_face_node < face_nodes.size(); ++i_face_node) {
              const Rd& xr0 = xr[face_nodes[i_face_node]];
              const Rd& xr1 = xr[face_nodes[(i_face_node + 1) % face_nodes.size()]];

              const Rd xjxr0 = xr0 - xj;
              const Rd xjxr1 = xr1 - xj;

              // 6 times the signed volume of the tetrahedron (xj, xr0, xr1, xl)
              const double Vjlr = (crossProduct(xjxr0, xjxr1), xjxl);

              // tetrahedron centroid is (xl + xj + xr0 + xr1) / 4; 1/6 * 1/4 = 1/24 applied below
              sum += (sign * Vjlr) * (xl_plus_xj + xr0 + xr1);
            }
          }

          sum *= 1 / (24 * Vj[j]);

          cell_centroid[j] = sum;
        });

      m_cell_centroid = cell_centroid;
    }
  }

  // Computes m_Vj: the volume of each cell as Vj = (1/d) sum_r (xr, Cjr)
  PUGS_INLINE
  void
  _computeCellVolume()
  {
    const NodeValue<const Rd>& xr   = m_mesh.xr();
    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

    auto Cjr = this->Cjr();

    CellValue<double> Vj(m_mesh.connectivity());
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
        double sum_cjr_xr      = 0;
        const auto& cell_nodes = cell_to_node_matrix[j];

        for (size_t R = 0; R < cell_nodes.size(); ++R) {
          sum_cjr_xr += (xr[cell_nodes[R]], Cjr(j, R));
        }
        Vj[j] = inv_Dimension * sum_cjr_xr;
      });
    m_Vj = Vj;
  }

  // Computes m_Nlr: for each face l and each of its nodes r, the contribution
  // of r to the face normal. Requesting it in 1d is a compile-time error.
  PUGS_INLINE
  void
  _computeNlr()
  {
    if constexpr (Dimension == 1) {
      static_assert(Dimension != 1, "Nlr does not make sense in 1d");
    } else if constexpr (Dimension == 2) {
      const NodeValue<const Rd>& xr = m_mesh.xr();

      NodeValuePerFace<Rd> Nlr(m_mesh.connectivity());
      const auto& face_to_node_matrix = m_mesh.connectivity().faceToNodeMatrix();

      parallel_for(
        m_mesh.numberOfFaces(), PUGS_LAMBDA(FaceId l) {
          // a 2d face is a segment: each of its two nodes gets half of the
          // rotated edge vector
          const auto& face_nodes = face_to_node_matrix[l];

          const Rd xr0 = xr[face_nodes[0]];
          const Rd xr1 = xr[face_nodes[1]];
          const Rd dx  = xr1 - xr0;

          const Rd Nr = 0.5 * Rd{dx[1], -dx[0]};

          Nlr(l, 0) = Nr;
          Nlr(l, 1) = Nr;
        });
      m_Nlr = Nlr;
    } else {
      const NodeValue<const Rd>& xr = m_mesh.xr();

      NodeValuePerFace<Rd> Nlr(m_mesh.connectivity());
      const auto& face_to_node_matrix = m_mesh.connectivity().faceToNodeMatrix();

      parallel_for(
        m_mesh.numberOfFaces(), PUGS_LAMBDA(FaceId l) {
          const auto& face_nodes = face_to_node_matrix[l];
          const size_t nb_nodes  = face_nodes.size();

          // dxr[r] = x(r+1) - x(r-1), indices taken cyclically over the face nodes
          std::vector<Rd> dxr(nb_nodes);
          for (size_t r = 0; r < nb_nodes; ++r) {
            dxr[r] = xr[face_nodes[(r + 1) % nb_nodes]] - xr[face_nodes[(r + nb_nodes - 1) % nb_nodes]];
          }
          const double inv_12_nb_nodes = 1. / (12. * nb_nodes);
          for (size_t r = 0; r < nb_nodes; ++r) {
            Rd Nr            = zero;
            const Rd two_dxr = 2 * dxr[r];
            for (size_t s = 0; s < nb_nodes; ++s) {
              Nr += crossProduct((two_dxr - dxr[s]), xr[face_nodes[s]]);
            }
            Nr *= inv_12_nb_nodes;
            Nr -= (1. / 6.) * crossProduct(dxr[r], xr[face_nodes[r]]);
            Nlr(l, r) = Nr;
          }
        });
      m_Nlr = Nlr;
    }
  }

  // Computes m_Cjr, the vectors attached to each (cell, node) pair:
  //  - 1d: -1 for the first node of the cell, +1 for the second;
  //  - 2d: half of the rotated vector x(R+1) - x(R-1);
  //  - 3d: sum of the Nlr of the cell faces containing the node, oriented
  //        with the cell (subtracted when the face is reversed).
  PUGS_INLINE
  void
  _computeCjr()
  {
    if constexpr (Dimension == 1) {
      NodeValuePerCell<Rd> Cjr(m_mesh.connectivity());
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
          Cjr(j, 0) = -1;
          Cjr(j, 1) = 1;
        });
      m_Cjr = Cjr;
    } else if constexpr (Dimension == 2) {
      const NodeValue<const Rd>& xr   = m_mesh.xr();
      const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();

      NodeValuePerCell<Rd> Cjr(m_mesh.connectivity());
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
          const auto& cell_nodes = cell_to_node_matrix[j];
          const size_t nb_nodes  = cell_nodes.size();
          for (size_t R = 0; R < nb_nodes; ++R) {
            const size_t Rp1        = (R + 1) % nb_nodes;
            const size_t Rm1        = (R + nb_nodes - 1) % nb_nodes;
            const Rd half_xrp_xrm   = 0.5 * (xr[cell_nodes[Rp1]] - xr[cell_nodes[Rm1]]);
            Cjr(j, R)               = Rd{half_xrp_xrm[1], -half_xrp_xrm[0]};
          }
        });
      m_Cjr = Cjr;
    } else {
      // Dimension == 3 (guaranteed by the class-level static_asserts)
      auto Nlr = this->Nlr();

      const auto& face_to_node_matrix   = m_mesh.connectivity().faceToNodeMatrix();
      const auto& cell_to_node_matrix   = m_mesh.connectivity().cellToNodeMatrix();
      const auto& cell_to_face_matrix   = m_mesh.connectivity().cellToFaceMatrix();
      const auto& cell_face_is_reversed = m_mesh.connectivity().cellFaceIsReversed();

      NodeValuePerCell<Rd> Cjr(m_mesh.connectivity());
      parallel_for(
        Cjr.numberOfValues(), PUGS_LAMBDA(size_t jr) { Cjr[jr] = zero; });

      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
          const auto& cell_nodes = cell_to_node_matrix[j];

          const auto& cell_faces       = cell_to_face_matrix[j];
          const auto& face_is_reversed = cell_face_is_reversed.itemValues(j);

          // local index of node_number among the nodes of cell j, or
          // numeric_limits<size_t>::max() if it does not belong to the cell
          auto local_node_number_in_cell = [&](NodeId node_number) {
            for (size_t i_node = 0; i_node < cell_nodes.size(); ++i_node) {
              if (node_number == cell_nodes[i_node]) {
                return i_node;
              }
            }
            return std::numeric_limits<size_t>::max();
          };

          for (size_t L = 0; L < cell_faces.size(); ++L) {
            const FaceId l         = cell_faces[L];
            const auto& face_nodes = face_to_node_matrix[l];

            // Nlr is oriented with the face: flip it when the face is reversed for this cell
            const double sign = face_is_reversed[L] ? -1 : 1;

            for (size_t rl = 0; rl < face_nodes.size(); ++rl) {
              const size_t R = local_node_number_in_cell(face_nodes[rl]);
              // a face node missing from its cell is a connectivity error; never index with the sentinel
              Assert(R < cell_nodes.size());
              Cjr(j, R) += sign * Nlr(l, rl);
            }
          }
        });

      m_Cjr = Cjr;
    }
  }

  // Computes m_ljr = |Cjr| (identically 1 in 1d, where Cjr is -1 or +1)
  PUGS_INLINE
  void
  _compute_ljr()
  {
    NodeValuePerCell<double> ljr(m_mesh.connectivity());
    if constexpr (Dimension == 1) {
      parallel_for(
        ljr.numberOfValues(), PUGS_LAMBDA(size_t jr) { ljr[jr] = 1; });
    } else {
      auto Cjr = this->Cjr();
      parallel_for(
        Cjr.numberOfValues(), PUGS_LAMBDA(size_t jr) { ljr[jr] = l2Norm(Cjr[jr]); });
    }
    m_ljr = ljr;
  }

  // Computes m_njr = Cjr / |Cjr|
  PUGS_INLINE
  void
  _compute_njr()
  {
    auto Cjr = this->Cjr();
    if constexpr (Dimension == 1) {
      // in 1d Cjr is already unitary: njr = Cjr (shallow copy)
      m_njr = Cjr;
    } else {
      auto ljr = this->ljr();

      NodeValuePerCell<Rd> njr(m_mesh.connectivity());
      parallel_for(
        Cjr.numberOfValues(), PUGS_LAMBDA(size_t jr) { njr[jr] = (1. / ljr[jr]) * Cjr[jr]; });
      m_njr = njr;
    }
  }

  // Throws NormalError if any cell has a non-positive volume. The local result
  // is combined with parallel::allReduceAnd (presumably over all processes).
  void
  _checkCellVolume()
  {
    Assert(m_Vj.isBuilt());

    bool is_valid = [&] {
      for (CellId j = 0; j < m_mesh.numberOfCells(); ++j) {
        if (m_Vj[j] <= 0) {
          return false;
        }
      }
      return true;
    }();

    if (not parallel::allReduceAnd(is_valid)) {
      throw NormalError("mesh contains cells of non-positive volume");
    }
  }

 public:
  // The mesh this data is attached to
  PUGS_INLINE
  const MeshType&
  mesh() const
  {
    return m_mesh;
  }

  // Sum of the norms of the Nlr of each face (2d and 3d only)
  PUGS_INLINE
  FaceValue<const double>
  ll()
  {
    if (not m_ll.isBuilt()) {
      this->_compute_ll();
    }
    return m_ll;
  }

  // Per face node contributions to the face normal (2d and 3d only)
  PUGS_INLINE
  NodeValuePerFace<const Rd>
  Nlr()
  {
    if (not m_Nlr.isBuilt()) {
      this->_computeNlr();
    }
    return m_Nlr;
  }

  // Unit-length Nlr (2d and 3d only)
  PUGS_INLINE
  NodeValuePerFace<const Rd>
  nlr()
  {
    if (not m_nlr.isBuilt()) {
      this->_compute_nlr();
    }
    return m_nlr;
  }

  // Vectors attached to each (cell, node) pair
  PUGS_INLINE
  NodeValuePerCell<const Rd>
  Cjr()
  {
    if (not m_Cjr.isBuilt()) {
      this->_computeCjr();
    }
    return m_Cjr;
  }

  // Norms of the Cjr vectors
  PUGS_INLINE
  NodeValuePerCell<const double>
  ljr()
  {
    if (not m_ljr.isBuilt()) {
      this->_compute_ljr();
    }
    return m_ljr;
  }

  // Unit-length Cjr vectors
  PUGS_INLINE
  NodeValuePerCell<const Rd>
  njr()
  {
    if (not m_njr.isBuilt()) {
      this->_compute_njr();
    }
    return m_njr;
  }

  // Face vertices isobarycenters (2d and 3d only)
  PUGS_INLINE
  FaceValue<const Rd>
  xl()
  {
    if (not m_xl.isBuilt()) {
      this->_computeFaceIsobarycenter();
    }
    return m_xl;
  }

  // Cell vertices isobarycenters
  PUGS_INLINE
  CellValue<const Rd>
  cellIsoBarycenter()
  {
    if (not m_cell_iso_barycenter.isBuilt()) {
      this->_computeCellIsoBarycenter();
    }
    return m_cell_iso_barycenter;
  }

  // Cell centroids (equal to the isobarycenters in 1d)
  PUGS_INLINE
  CellValue<const Rd>
  xj()
  {
    if (not m_cell_centroid.isBuilt()) {
      this->_computeCellCentroid();
    }
    return m_cell_centroid;
  }

  // Cell volumes; throws NormalError on first computation if any is non-positive
  PUGS_INLINE
  CellValue<const double>
  Vj()
  {
    if (not m_Vj.isBuilt()) {
      this->_computeCellVolume();
      this->_checkCellVolume();
    }
    return m_Vj;
  }

 private:
  // MeshData **must** be constructed through MeshDataManager
  friend class MeshDataManager;
  MeshData(const MeshType& mesh) : m_mesh(mesh) {}

 public:
  MeshData() = delete;

  MeshData(const MeshData&) = delete;
  MeshData(MeshData&&)      = delete;

  MeshData& operator=(const MeshData&) = delete;
  MeshData& operator=(MeshData&&) = delete;

  ~MeshData() = default;
};

#endif   // MESH_DATA_HPP