#include <algebra/LeastSquareSolver.hpp>
#include <algebra/LinearSolver.hpp>
#include <algebra/SmallMatrix.hpp>
#include <algebra/TinyVector.hpp>
#include <algebra/Vector.hpp>
#include <language/algorithms/ParabolicHeat.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/DualConnectivityManager.hpp>
#include <mesh/DualMeshManager.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <mesh/PrimalToDiamondDualConnectivityDataMapper.hpp>
#include <mesh/SubItemValuePerItem.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/FourierBoundaryConditionDescriptor.hpp>
#include <scheme/NeumannBoundaryConditionDescriptor.hpp>
#include <scheme/ScalarDiamondScheme.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>
#include <utils/Timer.hpp>

/**
 * Advances the heat equation from t = 0 to Tf with repeated calls to
 * LegacyScalarDiamondScheme, then prints error norms of the final cell and
 * boundary-face temperatures against the function T_id.
 *
 * @param i_mesh              mesh; dynamically cast to Mesh<Connectivity<Dimension>>
 * @param bc_descriptor_list  boundary conditions; only Dirichlet and Neumann are
 *                            handled, and only for Dimension > 1
 * @param T_id                function compared with the computed temperature after the
 *                            last step (evaluated at mesh_data.xj() and mesh_data.xl())
 * @param T_init_id           initial temperature (evaluated at mesh_data.xj() and mesh_data.xl())
 * @param kappa_id            forwarded to LegacyScalarDiamondScheme (presumably the diffusion coefficient)
 * @param f_id                forwarded to LegacyScalarDiamondScheme (presumably the source term)
 * @param Tf                  final time
 * @param dt                  time step; the last step is shortened so that the run ends at Tf
 *
 * @throws NormalError          if dt is not positive, a boundary condition type is invalid,
 *                              or a face belongs to two boundary conditions
 * @throws NotImplementedError  in 1d, for symmetry/Fourier conditions, or when run in parallel
 */
template <size_t Dimension>
ParabolicHeatScheme<Dimension>::ParabolicHeatScheme(
  std::shared_ptr<const IMesh> i_mesh,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const FunctionSymbolId& T_id,
  const FunctionSymbolId& T_init_id,
  const FunctionSymbolId& kappa_id,
  const FunctionSymbolId& f_id,
  const double& Tf,
  const double& dt)
{
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;
  using MeshDataType     = MeshData<Dimension>;

  using BoundaryCondition = std::variant<DirichletBoundaryCondition, FourierBoundaryCondition, NeumannBoundaryCondition,
                                         SymmetryBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;

  // Checked at runtime rather than with Assert (which vanishes in release builds):
  // with dt <= 0 the time loop below never reaches Tf and would run forever.
  // Written as not(dt > 0) so that a NaN time step is rejected as well.
  if (not(dt > 0)) {
    throw NormalError("The time-step has to be positive!");
  }

  BoundaryConditionList boundary_condition_list;

  std::cout << "number of bc descr = " << bc_descriptor_list.size() << '\n';
  std::shared_ptr mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);

  for (const auto& bc_descriptor : bc_descriptor_list) {
    bool is_valid_boundary_condition = true;

    switch (bc_descriptor->type()) {
    case IBoundaryConditionDescriptor::Type::symmetry: {
      throw NotImplementedError("NIY");
      break;
    }
    case IBoundaryConditionDescriptor::Type::dirichlet: {
      const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
        dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
      if constexpr (Dimension > 1) {
        MeshFaceBoundary<Dimension> mesh_face_boundary =
          getMeshFaceBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

        const FunctionSymbolId g_id = dirichlet_bc_descriptor.rhsSymbolId();
        MeshDataType& mesh_data     = MeshDataManager::instance().getMeshData(*mesh);

        // Boundary data evaluated at mesh_data.xl() for each face of the boundary
        Array<const double> value_list =
          InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id,
                                                                                                    mesh_data.xl(),
                                                                                                    mesh_face_boundary
                                                                                                      .faceList());
        boundary_condition_list.push_back(DirichletBoundaryCondition{mesh_face_boundary.faceList(), value_list});
      } else {
        throw NotImplementedError("not implemented in 1d");
      }
      break;
    }
    case IBoundaryConditionDescriptor::Type::fourier: {
      throw NotImplementedError("NIY");
      break;
    }
    case IBoundaryConditionDescriptor::Type::neumann: {
      const NeumannBoundaryConditionDescriptor& neumann_bc_descriptor =
        dynamic_cast<const NeumannBoundaryConditionDescriptor&>(*bc_descriptor);

      if constexpr (Dimension > 1) {
        MeshFaceBoundary<Dimension> mesh_face_boundary =
          getMeshFaceBoundary(*mesh, neumann_bc_descriptor.boundaryDescriptor());

        const FunctionSymbolId g_id = neumann_bc_descriptor.rhsSymbolId();

        MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

        // Boundary data evaluated at mesh_data.xl() for each face of the boundary
        Array<const double> value_list =
          InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::face>(g_id,
                                                                                                    mesh_data.xl(),
                                                                                                    mesh_face_boundary
                                                                                                      .faceList());
        boundary_condition_list.push_back(NeumannBoundaryCondition{mesh_face_boundary.faceList(), value_list});
      } else {
        throw NotImplementedError("not implemented in 1d");
      }

      break;
    }
    default: {
      is_valid_boundary_condition = false;
    }
    }
    if (not is_valid_boundary_condition) {
      std::ostringstream error_msg;
      error_msg << *bc_descriptor << " is an invalid boundary condition for heat equation";
      throw NormalError(error_msg.str());
    }
  }

  if constexpr (Dimension > 1) {
    // Cell degrees of freedom are numbered like the cells themselves
    const CellValue<const size_t> cell_dof_number = [&] {
      CellValue<size_t> compute_cell_dof_number{mesh->connectivity()};
      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) { compute_cell_dof_number[cell_id] = cell_id; });
      return compute_cell_dof_number;
    }();
    size_t number_of_dof = mesh->numberOfCells();

    // Boundary-condition faces get dof numbers after the cells. The dof numbers are not
    // used further in this function, but building them rejects any face that appears in
    // two boundary conditions.
    const FaceValue<const size_t> face_dof_number = [&] {
      FaceValue<size_t> compute_face_dof_number{mesh->connectivity()};
      compute_face_dof_number.fill(std::numeric_limits<size_t>::max());
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr ((std::is_same_v<T, NeumannBoundaryCondition>) or
                          (std::is_same_v<T, FourierBoundaryCondition>) or
                          (std::is_same_v<T, SymmetryBoundaryCondition>) or
                          (std::is_same_v<T, DirichletBoundaryCondition>)) {
              const auto& face_list = bc.faceList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id = face_list[i_face];
                if (compute_face_dof_number[face_id] != std::numeric_limits<size_t>::max()) {
                  std::ostringstream os;
                  os << "The face " << face_id << " is used at least twice for boundary conditions";
                  throw NormalError(os.str());
                } else {
                  compute_face_dof_number[face_id] = number_of_dof++;
                }
              }
            }
          },
          boundary_condition);
      }

      return compute_face_dof_number;
    }();

    // Faces carrying a flux-type condition (Neumann, Fourier or symmetry)
    const FaceValue<const bool> primal_face_is_neumann = [&] {
      FaceValue<bool> face_is_neumann{mesh->connectivity()};
      face_is_neumann.fill(false);
      for (const auto& boundary_condition : boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr ((std::is_same_v<T, NeumannBoundaryCondition>) or
                          (std::is_same_v<T, FourierBoundaryCondition>) or
                          (std::is_same_v<T, SymmetryBoundaryCondition>)) {
              const auto& face_list = bc.faceList();

              for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                const FaceId face_id     = face_list[i_face];
                face_is_neumann[face_id] = true;
              }
            }
          },
          boundary_condition);
      }

      return face_is_neumann;
    }();

    const auto& primal_face_to_node_matrix = mesh->connectivity().faceToNodeMatrix();
    const auto& face_to_cell_matrix        = mesh->connectivity().faceToCellMatrix();
    NodeValue<bool> primal_node_is_on_boundary(mesh->connectivity());
    // A face with a single neighbouring cell is a boundary face only for a non-partitioned mesh
    if (parallel::size() > 1) {
      throw NotImplementedError("Calculation of node_is_on_boundary is incorrect");
    }

    primal_node_is_on_boundary.fill(false);
    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      if (face_to_cell_matrix[face_id].size() == 1) {
        for (size_t i_node = 0; i_node < primal_face_to_node_matrix[face_id].size(); ++i_node) {
          NodeId node_id                      = primal_face_to_node_matrix[face_id][i_node];
          primal_node_is_on_boundary[node_id] = true;
        }
      }
    }

    FaceValue<bool> primal_face_is_on_boundary(mesh->connectivity());
    if (parallel::size() > 1) {
      throw NotImplementedError("Calculation of face_is_on_boundary is incorrect");
    }

    primal_face_is_on_boundary.fill(false);
    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      if (face_to_cell_matrix[face_id].size() == 1) {
        primal_face_is_on_boundary[face_id] = true;
      }
    }

    // Boundary faces that do not carry a flux-type condition
    FaceValue<bool> primal_face_is_dirichlet(mesh->connectivity());
    if (parallel::size() > 1) {
      throw NotImplementedError("Calculation of face_is_dirichlet is incorrect");
    }

    primal_face_is_dirichlet.fill(false);
    for (FaceId face_id = 0; face_id < mesh->numberOfFaces(); ++face_id) {
      primal_face_is_dirichlet[face_id] = (primal_face_is_on_boundary[face_id] && (!primal_face_is_neumann[face_id]));
    }
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

    // Reference values (T_id) and initial state (T_init_id) on cells and faces
    CellValue<double> Tj =
      InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(T_id, mesh_data.xj());
    CellValue<double> Temperature =
      InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(T_init_id,
                                                                                                mesh_data.xj());
    FaceValue<double> Tl =
      InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::face>(T_id, mesh_data.xl());
    FaceValue<double> Temperature_face =
      InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::face>(T_init_id,
                                                                                                mesh_data.xl());

    double time = 0;
    const CellValue<const double> primal_Vj = mesh_data.Vj();

    // The node interpolation weights depend only on the mesh, so they are computed once
    InterpolationWeightsManager iwm(mesh, primal_face_is_on_boundary, primal_node_is_on_boundary);
    iwm.compute();
    CellValuePerNode<double> w_rj = iwm.wrj();
    FaceValuePerNode<double> w_rl = iwm.wrl();
    do {
      // The last step is shortened so that the run ends exactly at Tf
      double deltat = std::min(dt, Tf - time);
      std::cout << "Current time = " << time << " time-step = " << deltat << " final time = " << Tf << "\n";
      LegacyScalarDiamondScheme<Dimension>(i_mesh, bc_descriptor_list, kappa_id, f_id, Temperature, Temperature_face,
                                           Tf, deltat, w_rj, w_rl);
      time += deltat;
    } while (time < Tf && std::abs(time - Tf) > 1e-15);
    {
      // Error norms: cell errors are weighted by sqrt(Vj), and face errors by sqrt(ll)
      // on boundary faces only.
      Vector<double> error{mesh->numberOfCells()};
      CellValue<double> cell_error{mesh->connectivity()};
      Vector<double> face_error{mesh->numberOfFaces()};
      double error_max                    = 0.;
      size_t cell_max                     = 0;
      const FaceValue<const double> mes_l = [&] {
        if constexpr (Dimension == 1) {
          FaceValue<double> compute_mes_l{mesh->connectivity()};
          compute_mes_l.fill(1);
          return compute_mes_l;
        } else {
          return mesh_data.ll();
        }
      }();

      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId cell_id) {
          error[cell_id]      = (Temperature[cell_id] - Tj[cell_id]) * sqrt(primal_Vj[cell_id]);
          cell_error[cell_id] = (Temperature[cell_id] - Tj[cell_id]);
        });
      parallel_for(
        mesh->numberOfFaces(), PUGS_LAMBDA(FaceId face_id) {
          if (primal_face_is_on_boundary[face_id]) {
            face_error[face_id] = (Temperature_face[face_id] - Tl[face_id]) * sqrt(mes_l[face_id]);
          } else {
            face_error[face_id] = 0;
          }
        });
      for (CellId cell_id = 0; cell_id < mesh->numberOfCells(); cell_id++) {
        if (error_max < std::abs(cell_error[cell_id])) {
          error_max = std::abs(cell_error[cell_id]);
          cell_max  = cell_id;
        }
      }

      std::cout << " ||Error||_max (cell)= " << error_max << " on cell " << cell_max << "\n";
      std::cout << "||Error||_2 (cell)= " << std::sqrt(dot(error, error)) << "\n";
      std::cout << "||Error||_2 (face)= " << std::sqrt(dot(face_error, face_error)) << "\n";
      std::cout << "||Error||_2 (total)= " << std::sqrt(dot(error, error)) + std::sqrt(dot(face_error, face_error))
                << "\n";
    }
  } else {
    throw NotImplementedError("not done in 1d");
  }
}

template <size_t Dimension>
class ParabolicHeatScheme<Dimension>::DirichletBoundaryCondition
{
 public:
  /**
   * Dirichlet data: one value per face of the boundary.
   * In this file the values are the boundary function evaluated at mesh_data.xl().
   *
   * @param face_list   faces the condition applies to
   * @param value_list  values, same size and order as face_list
   */
  DirichletBoundaryCondition(const Array<const FaceId>& face_list, const Array<const double>& value_list)
    : m_face_list(face_list), m_value_list(value_list)
  {
    Assert(m_value_list.size() == m_face_list.size());
  }

  ~DirichletBoundaryCondition() = default;

  /// Faces the condition applies to
  const Array<const FaceId>& faceList() const { return m_face_list; }

  /// Values attached to the faces, indexed like faceList()
  const Array<const double>& valueList() const { return m_value_list; }

 private:
  const Array<const FaceId> m_face_list;
  const Array<const double> m_value_list;
};

template <size_t Dimension>
class ParabolicHeatScheme<Dimension>::NeumannBoundaryCondition
{
  // Per-face data, stored alongside the faces it refers to
  const Array<const FaceId> m_face_list;
  const Array<const double> m_value_list;

 public:
  /**
   * Neumann data: one value per face of the boundary.
   * In this file the values are the boundary function evaluated at mesh_data.xl().
   *
   * @param face_list   faces the condition applies to
   * @param value_list  values, same size and order as face_list
   */
  NeumannBoundaryCondition(const Array<const FaceId>& face_list, const Array<const double>& value_list)
    : m_face_list(face_list), m_value_list(value_list)
  {
    Assert(m_value_list.size() == m_face_list.size());
  }

  ~NeumannBoundaryCondition() = default;

  /// Values attached to the faces, indexed like faceList()
  const Array<const double>& valueList() const { return m_value_list; }

  /// Faces the condition applies to
  const Array<const FaceId>& faceList() const { return m_face_list; }
};

template <size_t Dimension>
class ParabolicHeatScheme<Dimension>::FourierBoundaryCondition
{
 public:
  /**
   * Fourier (Robin-type) data: one coefficient and one value per boundary face.
   * Note that the constructor of ParabolicHeatScheme currently rejects Fourier
   * descriptors with NotImplementedError, so this type is not built there.
   *
   * @param face_list   faces the condition applies to
   * @param coef_list   coefficients, same size and order as face_list
   * @param value_list  values, same size and order as face_list
   */
  FourierBoundaryCondition(const Array<const FaceId>& face_list,
                           const Array<const double>& coef_list,
                           const Array<const double>& value_list)
    : m_face_list(face_list), m_coef_list(coef_list), m_value_list(value_list)
  {
    Assert(m_coef_list.size() == m_face_list.size());
    Assert(m_value_list.size() == m_face_list.size());
  }

  ~FourierBoundaryCondition() = default;

  /// Faces the condition applies to
  const Array<const FaceId>& faceList() const { return m_face_list; }

  /// Coefficients attached to the faces, indexed like faceList()
  const Array<const double>& coefList() const { return m_coef_list; }

  /// Values attached to the faces, indexed like faceList()
  const Array<const double>& valueList() const { return m_value_list; }

 private:
  const Array<const FaceId> m_face_list;
  const Array<const double> m_coef_list;
  const Array<const double> m_value_list;
};

template <size_t Dimension>
class ParabolicHeatScheme<Dimension>::SymmetryBoundaryCondition
{
  // Faces of the symmetry boundary; no per-face data is needed
  const Array<const FaceId> m_face_list;

 public:
  /// @param face_list  faces the symmetry condition applies to
  SymmetryBoundaryCondition(const Array<const FaceId>& face_list) : m_face_list(face_list) {}

  ~SymmetryBoundaryCondition() = default;

  /// Faces the condition applies to
  const Array<const FaceId>& faceList() const { return m_face_list; }
};

/**
 * Computes, for every primal node r, interpolation weights that express the
 * node position from nearby points:
 *  - w_rj: weights of the cells around the node (points mesh_data.xj()),
 *  - w_rl: for nodes flagged as boundary nodes, weights of their adjacent
 *          boundary faces (points mesh_data.xl()).
 *
 * For each node, with interpolation points X_k (cells first, then boundary
 * faces in node-to-face order), the weights x solve the (Dimension+1) x n system
 *     sum_k x_k = 1,    sum_k x_k X_k = x_r
 * through LeastSquareSolver::solveLocalSystem.
 * Entries of w_rl that are not computed are left as signaling NaN.
 */
template <size_t Dimension>
class ParabolicHeatScheme<Dimension>::InterpolationWeightsManager
{
 private:
  std::shared_ptr<const Mesh<Connectivity<Dimension>>> m_mesh;
  FaceValue<bool> m_primal_face_is_on_boundary;
  NodeValue<bool> m_primal_node_is_on_boundary;
  CellValuePerNode<double> m_w_rj;
  FaceValuePerNode<double> m_w_rl;

 public:
  /**
   * @param mesh                        primal mesh
   * @param primal_face_is_on_boundary  faces that may be added to a node stencil
   * @param primal_node_is_on_boundary  nodes whose stencil is extended with boundary faces
   */
  InterpolationWeightsManager(std::shared_ptr<const Mesh<Connectivity<Dimension>>> mesh,
                              FaceValue<bool> primal_face_is_on_boundary,
                              NodeValue<bool> primal_node_is_on_boundary)
    : m_mesh(mesh),
      m_primal_face_is_on_boundary(primal_face_is_on_boundary),
      m_primal_node_is_on_boundary(primal_node_is_on_boundary)
  {}
  ~InterpolationWeightsManager() = default;

  /// Cell weights per node; only filled by compute()
  CellValuePerNode<double>&
  wrj()
  {
    return m_w_rj;
  }

  /// Face weights per node; only filled by compute() (NaN where unused)
  FaceValuePerNode<double>&
  wrl()
  {
    return m_w_rl;
  }

  /// Computes w_rj and w_rl for every node of the mesh
  void
  compute()
  {
    using MeshDataType      = MeshData<Dimension>;
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

    const NodeValue<const TinyVector<Dimension>>& xr = m_mesh->xr();

    const FaceValue<const TinyVector<Dimension>>& xl = mesh_data.xl();
    const CellValue<const TinyVector<Dimension>>& xj = mesh_data.xj();
    const auto& node_to_cell_matrix                  = m_mesh->connectivity().nodeToCellMatrix();
    const auto& node_to_face_matrix                  = m_mesh->connectivity().nodeToFaceMatrix();
    CellValuePerNode<double> w_rj{m_mesh->connectivity()};
    FaceValuePerNode<double> w_rl{m_mesh->connectivity()};

    // Face weights that are never computed (interior faces, interior nodes) stay NaN
    for (size_t i = 0; i < w_rl.numberOfValues(); ++i) {
      w_rl[i] = std::numeric_limits<double>::signaling_NaN();
    }

    // Local indices (within node_to_face_matrix[node]) of the boundary faces added to
    // the current node's stencil; hoisted so its storage is reused across nodes.
    std::vector<size_t> boundary_face_local_indices;

    for (NodeId i_node = 0; i_node < m_mesh->numberOfNodes(); ++i_node) {
      const auto& node_to_cell = node_to_cell_matrix[i_node];
      const auto& node_to_face = node_to_face_matrix[i_node];

      // Only nodes flagged as boundary nodes use boundary faces as extra points
      boundary_face_local_indices.clear();
      if (m_primal_node_is_on_boundary[i_node]) {
        for (size_t i_face = 0; i_face < node_to_face.size(); ++i_face) {
          const FaceId face_id = node_to_face[i_face];
          if (m_primal_face_is_on_boundary[face_id]) {
            boundary_face_local_indices.push_back(i_face);
          }
        }
      }

      const size_t nb_cells   = node_to_cell.size();
      const size_t nb_columns = nb_cells + boundary_face_local_indices.size();

      // Right-hand side (1, x_r)
      SmallVector<double> b{Dimension + 1};
      b[0] = 1;
      for (size_t i = 1; i < Dimension + 1; ++i) {
        b[i] = xr[i_node][i - 1];
      }

      // One column (1, X_k) per interpolation point: cells first, then boundary faces
      SmallMatrix<double> A{Dimension + 1, nb_columns};
      for (size_t j = 0; j < nb_cells; ++j) {
        const CellId cell_id = node_to_cell[j];
        A(0, j)              = 1;
        for (size_t i = 1; i < Dimension + 1; ++i) {
          A(i, j) = xj[cell_id][i - 1];
        }
      }
      for (size_t k = 0; k < boundary_face_local_indices.size(); ++k) {
        const FaceId face_id = node_to_face[boundary_face_local_indices[k]];
        A(0, nb_cells + k)   = 1;
        for (size_t i = 1; i < Dimension + 1; ++i) {
          A(i, nb_cells + k) = xl[face_id][i - 1];
        }
      }

      SmallVector<double> x{nb_columns};
      x = zero;

      LeastSquareSolver ls_solver;
      ls_solver.solveLocalSystem(A, x, b);

      for (size_t j = 0; j < nb_cells; ++j) {
        w_rj(i_node, j) = x[j];
      }
      for (size_t k = 0; k < boundary_face_local_indices.size(); ++k) {
        w_rl(i_node, boundary_face_local_indices[k]) = x[nb_cells + k];
      }
    }
    m_w_rj = w_rj;
    m_w_rl = w_rl;
  }
};

// Explicit instantiations of the constructor for mesh dimensions 1, 2 and 3.
// The Dimension == 1 version always throws NotImplementedError: boundary conditions
// and the time loop are only implemented for Dimension > 1.
template ParabolicHeatScheme<1>::ParabolicHeatScheme(
  std::shared_ptr<const IMesh>,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const double&,
  const double&);

template ParabolicHeatScheme<2>::ParabolicHeatScheme(
  std::shared_ptr<const IMesh>,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const double&,
  const double&);

template ParabolicHeatScheme<3>::ParabolicHeatScheme(
  std::shared_ptr<const IMesh>,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const FunctionSymbolId&,
  const double&,
  const double&);