#include <language/algorithms/Heat5PointsAlgorithm.hpp>

#include <algebra/CRSMatrix.hpp>
#include <algebra/CRSMatrixDescriptor.hpp>
#include <algebra/LeastSquareSolver.hpp>
#include <algebra/LinearSolver.hpp>
#include <algebra/LinearSolverOptions.hpp>
#include <algebra/SmallMatrix.hpp>
#include <algebra/TinyVector.hpp>
#include <algebra/Vector.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/DualConnectivityManager.hpp>
#include <mesh/DualMeshManager.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/PrimalToDiamondDualConnectivityDataMapper.hpp>

/**
 * @brief Solves the stationary heat problem -div(kappa grad T) = f with a
 *        two-point-flux finite volume scheme, and prints the discrete L2 error
 *        of the computed cell values against the exact solution T.
 *
 * Only Dimension == 2 is implemented; any other dimension throws
 * NotImplementedError (after printing the boundary descriptor count).
 *
 * For each primal face l, the flux between its two cells K and L is
 * beta_l * (T_K - T_L), with beta_l = 0.5 * alpha_l * |l| and
 * alpha_l = |l| * kappa / |D_l|, where D_l is the diamond dual cell of l.
 *
 * @param i_mesh             mesh, dynamically cast to Mesh<Connectivity<Dimension>>
 * @param bc_descriptor_list boundary condition descriptors; currently only their
 *                           count is printed. A boundary face (one adjacent cell)
 *                           only adds beta_l to the diagonal, which amounts to
 *                           T = 0 on the boundary (NOTE(review): presumably a
 *                           placeholder until BCs are wired in — confirm).
 * @param T_id               exact solution T(x), sampled at primal cell centers
 * @param kappa_id           conductivity kappa(x)
 * @param f_id               source term f(x), sampled at primal cell centers
 */
template <size_t Dimension>
Heat5PointsAlgorithm<Dimension>::Heat5PointsAlgorithm(
  std::shared_ptr<const IMesh> i_mesh,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const FunctionSymbolId& T_id,
  const FunctionSymbolId& kappa_id,
  const FunctionSymbolId& f_id)
{
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;
  using MeshDataType     = MeshData<Dimension>;

  std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

  std::cout << "number of bc descr = " << bc_descriptor_list.size() << '\n';

  if constexpr (Dimension == 2) {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

    // Exact solution at primal cell centers (reference for the error below).
    CellValue<double> Tj =
      InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(T_id, mesh_data.xj());

    // Nodal reconstruction of Tj: for each node r, weights w_rj over the cells
    // around r are obtained from the least-squares system
    //   sum_j w_rj = 1,  sum_j w_rj * x_j = x_r,
    // so that affine fields are reproduced. Tr and w_rj are not used by the
    // 5-point assembly below (Tr only feeds Trd).
    NodeValue<double> Tr(mesh->connectivity());
    const NodeValue<const TinyVector<Dimension>>& xr = mesh->xr();
    const CellValue<const TinyVector<Dimension>>& xj = mesh_data.xj();
    const auto& node_to_cell_matrix                  = mesh->connectivity().nodeToCellMatrix();
    CellValuePerNode<double> w_rj{mesh->connectivity()};

    for (NodeId i_node = 0; i_node < mesh->numberOfNodes(); ++i_node) {
      SmallVector<double> b{Dimension + 1};
      b[0] = 1;
      for (size_t i = 1; i < Dimension + 1; i++) {
        b[i] = xr[i_node][i - 1];
      }
      const auto& node_to_cell = node_to_cell_matrix[i_node];

      SmallMatrix<double> A{Dimension + 1, node_to_cell.size()};
      for (size_t j = 0; j < node_to_cell.size(); j++) {
        A(0, j) = 1;
      }
      for (size_t i = 1; i < Dimension + 1; i++) {
        for (size_t j = 0; j < node_to_cell.size(); j++) {
          const CellId J = node_to_cell[j];
          A(i, j)        = xj[J][i - 1];
        }
      }
      SmallVector<double> x{node_to_cell.size()};
      x = zero;

      LeastSquareSolver ls_solver;
      ls_solver.solveLocalSystem(A, x, b);

      Tr[i_node] = 0;
      for (size_t j = 0; j < node_to_cell.size(); j++) {
        Tr[i_node] += x[j] * Tj[node_to_cell[j]];
        w_rj(i_node, j) = x[j];
      }
    }

    {
      std::shared_ptr diamond_mesh = DualMeshManager::instance().getDiamondDualMesh(*mesh);

      MeshDataType& diamond_mesh_data = MeshDataManager::instance().getMeshData(*diamond_mesh);

      std::shared_ptr mapper =
        DualConnectivityManager::instance().getPrimalToDiamondDualConnectivityDataMapper(mesh->connectivity());

      // Nodal values on the diamond mesh (built from Tr and Tj); not used by the
      // 5-point assembly below.
      NodeValue<double> Trd{diamond_mesh->connectivity()};

      mapper->toDualNode(Tr, Tj, Trd);

      // NOTE(review): kappaj is currently unused; kappa enters the scheme through
      // dual_kappaj (sampled at diamond cell centers).
      CellValue<double> kappaj =
        InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(kappa_id,
                                                                                                  mesh_data.xj());

      CellValue<double> dual_kappaj =
        InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(kappa_id,
                                                                                                  diamond_mesh_data
                                                                                                    .xj());

      const CellValue<const double> dual_Vj = diamond_mesh_data.Vj();

      // Primal face measures |l| (set to 1 in dimension 1).
      const FaceValue<const double> mes_l = [&] {
        if constexpr (Dimension == 1) {
          FaceValue<double> compute_mes_l{mesh->connectivity()};
          compute_mes_l.fill(1);
          return compute_mes_l;
        } else {
          return mesh_data.ll();
        }
      }();

      // |l| transported onto the diamond cell attached to face l.
      const CellValue<const double> dual_mes_l_j = [=] {
        CellValue<double> compute_mes_j{diamond_mesh->connectivity()};
        mapper->toDualCell(mes_l, compute_mes_j);

        return compute_mes_j;
      }();

      // alpha_l = |l| * kappa_D / |D| computed per diamond cell D, then mapped
      // back to the primal face it is attached to.
      FaceValue<const double> alpha_l = [&] {
        CellValue<double> alpha_j{diamond_mesh->connectivity()};

        parallel_for(
          diamond_mesh->numberOfCells(), PUGS_LAMBDA(CellId diamond_cell_id) {
            alpha_j[diamond_cell_id] =
              dual_mes_l_j[diamond_cell_id] * dual_kappaj[diamond_cell_id] / dual_Vj[diamond_cell_id];
          });

        FaceValue<double> computed_alpha_l{mesh->connectivity()};
        mapper->fromDualCell(alpha_j, computed_alpha_l);

        return computed_alpha_l;
      }();

      const auto& face_to_cell_matrix = mesh->connectivity().faceToCellMatrix();

      // Row profile of S: the diagonal entry plus one off-diagonal entry per
      // interior face (a face shared by two cells). This is exact as long as two
      // cells share at most one face.
      Array<int> non_zeros{mesh->numberOfCells()};
      non_zeros.fill(1);
      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        const auto& face_cells = face_to_cell_matrix[face_id];
        if (face_cells.size() == 2) {
          const CellId first_cell_id  = face_cells[0];
          const CellId second_cell_id = face_cells[1];
          ++non_zeros[first_cell_id];
          ++non_zeros[second_cell_id];
        }
      }
      CRSMatrixDescriptor<double> S(mesh->numberOfCells(), mesh->numberOfCells(), non_zeros);

      // Each face l contributes the flux beta_l * (T_K - T_L) to the rows of its
      // cells: +beta_l on the diagonal and -beta_l off the diagonal. Together with
      // the right-hand side f_j * |K_j| this discretizes -div(kappa grad T) = f,
      // and S is symmetric with a positive diagonal. A boundary face (one cell)
      // only adds +beta_l to the diagonal.
      for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
        const auto& primal_face_to_cell = face_to_cell_matrix[face_id];

        const double beta_l = 0.5 * alpha_l[face_id] * mes_l[face_id];

        for (size_t i_cell = 0; i_cell < primal_face_to_cell.size(); ++i_cell) {
          const CellId cell_id1 = primal_face_to_cell[i_cell];
          for (size_t j_cell = 0; j_cell < primal_face_to_cell.size(); ++j_cell) {
            const CellId cell_id2 = primal_face_to_cell[j_cell];
            if (i_cell == j_cell) {
              S(cell_id1, cell_id2) += beta_l;
            } else {
              S(cell_id1, cell_id2) -= beta_l;
            }
          }
        }
      }
      CellValue<double> fj =
        InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(f_id, mesh_data.xj());

      const CellValue<const double> primal_Vj = mesh_data.Vj();
      CRSMatrix A{S.getCRSMatrix()};
      Vector<double> b{mesh->numberOfCells()};
      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { b[cell_id] = fj[cell_id] * primal_Vj[cell_id]; });

      Vector<double> T{mesh->numberOfCells()};
      T = zero;

      LinearSolver solver;
      solver.solveLocalSystem(A, T, b);

      CellValue<double> Temperature{mesh->connectivity()};

      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { Temperature[cell_id] = T[cell_id]; });

      // Discrete L2 norm: ||e||_2^2 = sum_j |K_j| * e_j^2, so each cell error is
      // weighted by sqrt(|K_j|) before taking the Euclidean norm.
      Vector<double> error{mesh->numberOfCells()};
      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
          error[cell_id] = (Temperature[cell_id] - Tj[cell_id]) * std::sqrt(primal_Vj[cell_id]);
        });

      std::cout << "||Error||_2 = " << std::sqrt(dot(error, error)) << "\n";
    }
  } else {
    throw NotImplementedError("not done in this dimension");
  }
}

// Explicit instantiations of the constructor for dimensions 1, 2 and 3. The
// constructor template is defined in this translation unit, so each dimension
// must be instantiated here. Only Dimension == 2 computes anything; the other
// two throw NotImplementedError from the constructor body.
template Heat5PointsAlgorithm<1>::Heat5PointsAlgorithm(std::shared_ptr<const IMesh>,
                                                       const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
                                                       const FunctionSymbolId&,
                                                       const FunctionSymbolId&,
                                                       const FunctionSymbolId&);

template Heat5PointsAlgorithm<2>::Heat5PointsAlgorithm(std::shared_ptr<const IMesh>,
                                                       const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
                                                       const FunctionSymbolId&,
                                                       const FunctionSymbolId&,
                                                       const FunctionSymbolId&);

template Heat5PointsAlgorithm<3>::Heat5PointsAlgorithm(std::shared_ptr<const IMesh>,
                                                       const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
                                                       const FunctionSymbolId&,
                                                       const FunctionSymbolId&,
                                                       const FunctionSymbolId&);