#include <scheme/ImplicitIterativeTrafficSolver.hpp>

#include <algebra/CRSMatrix.hpp>
#include <algebra/CRSMatrixDescriptor.hpp>
#include <algebra/LinearSolver.hpp>
#include <algebra/Vector.hpp>
#include <language/utils/FunctionSymbolId.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/DiscreteFunctionVariant.hpp>
#include <scheme/FreeBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>

#include <variant>
#include <vector>

template <MeshConcept MeshType>
class ImplicitIterativeTrafficSolverHandler::ImplicitIterativeTrafficSolver final
  : public ImplicitIterativeTrafficSolverHandler::IImplicitIterativeTrafficSolver
{
  constexpr static size_t Dimension = MeshType::Dimension;

  using Rdxd = TinyMatrix<Dimension>;
  using Rd   = TinyVector<Dimension>;

  using MeshDataType = MeshData<MeshType>;

  using DiscreteScalarFunction = DiscreteFunctionP0<const double>;
  using DiscreteVectorFunction = DiscreteFunctionP0<const Rd>;

  class FreeBoundaryCondition;
  class VelocityBoundaryCondition;

  using BoundaryCondition = std::variant<FreeBoundaryCondition, VelocityBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;

  // Declared before m_mesh, so it is initialized first: _getBCList must not use m_mesh.
  BoundaryConditionList m_boundary_condition_list;
  const MeshType& m_mesh;

  // NOTE(review): m_ur and m_upwind_flux are never assigned or read in this class.
  NodeValue<const Rd> m_ur;
  NodeValue<const Rd> m_upwind_flux;

  // tau_j / V_j for every cell, i.e. 1 / M_j with M_j = V_j / tau_j; computed once in the
  // constructor from the input tau.
  CellValue<const double> m_inv_Mj;

  // Last iterate of the fixed-point loop run by the constructor; consumed by apply().
  CellValue<double> m_tau;

  // rho_j = 1 / tau_j.  NOTE(review): not called anywhere in this class.
  CellValue<const double>
  _getRho(const DiscreteScalarFunction& tau)
  {
    CellValue<double> rho{m_mesh.connectivity()};
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { rho[j] = 1. / tau[j]; });
    return rho;
  }

  // u_j = 1 - 1 / tau_j.  NOTE(review): not called anywhere in this class.
  CellValue<const double>
  _getU(const DiscreteScalarFunction& tau)
  {
    CellValue<double> u{m_mesh.connectivity()};
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { u[j] = 1 - 1. / tau[j]; });
    return u;
  }

  // M_j = V_j / tau_j.  NOTE(review): not called anywhere in this class.
  const CellValue<const double>
  _getMj(const DiscreteScalarFunction& tau)
  {
    MeshDataType& mesh_data           = MeshDataManager::instance().getMeshData(m_mesh);
    const CellValue<const double>& Vj = mesh_data.Vj();
    CellValue<double> Mj{m_mesh.connectivity()};
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { Mj[j] = Vj[j] * 1. / tau[j]; });
    return Mj;
  }

  // 1 / M_j = tau_j / V_j, with V_j the cell volumes from the mesh data.
  const CellValue<const double>
  _getInv_Mj(const DiscreteScalarFunction& tau)
  {
    MeshDataType& mesh_data           = MeshDataManager::instance().getMeshData(m_mesh);
    const CellValue<const double>& Vj = mesh_data.Vj();
    CellValue<double> inv_Mj{m_mesh.connectivity()};
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { inv_Mj[j] = tau[j] * 1. / Vj[j]; });
    return inv_Mj;
  }

  // One step of the fixed-point iteration: builds the next iterate from tau at time t^n
  // (taun) and the current iterate tau^k (tau_iter).
  //
  //   b_j = 1 / (tau^k_j * tau^k_{j+1})
  //
  // For every node carrying a VelocityBoundaryCondition with prescribed velocity v, the
  // first cell j of that node gets
  //   tau_j = (taun_j + dt/M_j / tau^k_j) / (1 + dt/M_j * (1 - v) / tau^k_j)
  // then numberOfCells()-1 steps walk to the neighbour j' (first cell of cell j's first
  // node) with
  //   tau_j' = (taun_j' + dt/M_j' * b_j' * tau_j) / (1 + dt/M_j' * b_j').
  //
  // NOTE(review): this assumes a 1D mesh whose numbering and local node/cell orderings run
  // left to right, with the velocity boundary at the right end - confirm.  Cells that no
  // such sweep reaches (e.g. when there is no velocity boundary condition) are left unset.
  CellValue<double>
  _getNextTau(CellValue<const double> taun, CellValue<const double> tau_iter, const double dt)
  {
    CellValue<double> tau_next(m_mesh.connectivity());

    CellValue<double> bj(m_mesh.connectivity());

    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();
    const auto& node_to_cell_matrix = m_mesh.connectivity().nodeToCellMatrix();

    // The right neighbour j1 of cell j is the second cell of j's second node.  The last
    // cell has no right neighbour, so bj is not set there.
    for (CellId j = 0; j < m_mesh.numberOfCells() - 1; ++j) {
      const auto& cell_node = cell_to_node_matrix[j];
      NodeId r              = cell_node[1];
      const auto& node_cell = node_to_cell_matrix[r];
      CellId j1             = node_cell[1];
      bj[j]                 = 1. / (tau_iter[j] * tau_iter[j1]);
    }

    // Sweep from the velocity-boundary cell towards the other end, one cell at a time.
    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const auto& node_list  = bc.nodeList();
            const auto& value_list = bc.valueList();
            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId& node_id = node_list[i_node];
              const auto& node_cell = node_to_cell_matrix[node_id];
              CellId j              = node_cell[0];
              tau_next[j]           = (taun[j] + ((dt * m_inv_Mj[j]) / tau_iter[j])) /
                            (1 + dt * m_inv_Mj[j] * ((1 - value_list[i_node][0]) / tau_iter[j]));

              // Integer counter; "iter + 1 < n" cannot underflow when the mesh has no cell.
              for (size_t iter = 0; iter + 1 < m_mesh.numberOfCells(); ++iter) {
                const auto& cell_node   = cell_to_node_matrix[j];
                NodeId r                = cell_node[0];
                const auto& node_cell_2 = node_to_cell_matrix[r];
                CellId j_1              = node_cell_2[0];

                tau_next[j_1] =
                  (taun[j_1] + dt * m_inv_Mj[j_1] * bj[j_1] * tau_next[j]) / (1 + dt * m_inv_Mj[j_1] * bj[j_1]);
                j = j_1;
              }
            }
          }
        },
        boundary_condition);
    }

    return tau_next;
  }

  // Translates the user boundary-condition descriptors into solver boundary conditions.
  // Accepted: Dirichlet conditions named "velocity" (values interpolated at the boundary
  // nodes) and free conditions.  Anything else throws NormalError.
  // Uses only the `mesh` argument: it runs before m_mesh is initialized.
  BoundaryConditionList
  _getBCList(const std::shared_ptr<const MeshType>& mesh,
             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
  {
    BoundaryConditionList bc_list;

    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if (dirichlet_bc_descriptor.name() == "velocity") {
          MeshNodeBoundary mesh_node_boundary =
            getMeshNodeBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

          Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
            TinyVector<Dimension>)>::template interpolate<ItemType::node>(dirichlet_bc_descriptor.rhsSymbolId(),
                                                                          mesh->xr(), mesh_node_boundary.nodeList());

          bc_list.push_back(VelocityBoundaryCondition{mesh_node_boundary.nodeList(), value_list});
        } else {
          is_valid_boundary_condition = false;
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::free: {
        const FreeBoundaryConditionDescriptor& free_bc_descriptor =
          dynamic_cast<const FreeBoundaryConditionDescriptor&>(*bc_descriptor);

        MeshNodeBoundary mesh_node_boundary = getMeshNodeBoundary(*mesh, free_bc_descriptor.boundaryDescriptor());

        bc_list.push_back(FreeBoundaryCondition{mesh_node_boundary.nodeList()});
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for traffic solver";
        throw NormalError(error_msg.str());
      }
    }

    return bc_list;
  }

  // Runs the implicit fixed-point iteration for one time step dt, starting from tau, and
  // stores the last iterate in m_tau.  The loop stops when the max-norm of the change
  // between two iterates drops to the tolerance or below, or after the iteration cap.
  ImplicitIterativeTrafficSolver(
    const std::shared_ptr<const MeshType>& p_mesh,
    const DiscreteScalarFunction& tau,
    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
    const double& dt)
    : m_boundary_condition_list{this->_getBCList(p_mesh, bc_descriptor_list)},
      m_mesh{*p_mesh}   // NOTE(review): should m_tau{0} be used as the initial solution vector?
  {
    // Stopping criteria.  NOTE(review): 10e-8 is 1e-7; confirm 1e-8 was not intended.
    constexpr double tolerance             = 10e-8;
    constexpr int max_number_of_iterations = 200;

    m_inv_Mj = this->_getInv_Mj(tau);

    // tau at t^n, and the current iterate initialized from it.
    const CellValue<const double> taun = copy(tau.cellValues());
    CellValue<const double> tau_iter   = copy(taun);

    int nb_iter     = 0;
    double diff_tau = 0;

    do {
      std::cout << "iteration=" << nb_iter << '\n';
      nb_iter++;

      CellValue<double> next_tau = this->_getNextTau(taun, tau_iter, dt);

      // Max-norm of the change between two successive iterates.
      CellValue<double> abs_increment(m_mesh.connectivity());
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { abs_increment[j] = std::abs(next_tau[j] - tau_iter[j]); });
      diff_tau = max(abs_increment);

      tau_iter = next_tau;

    } while ((diff_tau > tolerance) and (nb_iter < max_number_of_iterations));

    m_tau = copy(tau_iter);
  }

 public:
  // Moves the nodes with the velocity 1 - 1/tau of an adjacent cell (velocity boundary
  // nodes use their prescribed velocity), and returns the new mesh together with tau
  // (taken from m_tau) defined on it.
  std::tuple<std::shared_ptr<const MeshVariant>, std::shared_ptr<DiscreteFunctionVariant>>
  apply(const double& dt, const std::shared_ptr<const MeshType>& mesh, const DiscreteScalarFunction& tau) const
  {
    const NodeValue<Rd> new_position = [&]() {
      const auto& node_to_cell_matrix = mesh->connectivity().nodeToCellMatrix();

      NodeValue<Rd> x_next(mesh->connectivity());

      // Node 0 has a single cell, so node_cells[0] is used.
      {
        NodeId r               = 0;
        const auto& node_cells = node_to_cell_matrix[r];
        CellId j               = node_cells[0];
        x_next[r]              = mesh->xr()[r] + dt * Rd{1 - 1. / m_tau[j]};
      }
      // Interior nodes take the velocity of their second cell.
      // NOTE(review): the last node is set only if it carries a velocity boundary condition.
      for (NodeId r = 1; r < mesh->connectivity().numberOfNodes() - 1; ++r) {
        const auto& node_cells = node_to_cell_matrix[r];
        CellId j               = node_cells[1];
        x_next[r]              = mesh->xr()[r] + dt * Rd{1 - 1. / m_tau[j]};
      }

      // Velocity boundary nodes move with the prescribed velocity.
      for (const auto& boundary_condition : m_boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
              const auto& node_list  = bc.nodeList();
              const auto& value_list = bc.valueList();
              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id = node_list[i_node];

                x_next[node_id] = mesh->xr()[node_id] + dt * value_list[i_node];
              }
            }
          },
          boundary_condition);
      }
      return x_next;
    }();

    std::shared_ptr<const MeshType> new_mesh = std::make_shared<MeshType>(mesh->shared_connectivity(), new_position);

    // Independent copy of m_tau, so the returned function does not share the solver's storage.
    CellValue<double> new_tau = copy(tau.cellValues());
    parallel_for(
      new_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { new_tau[j] = m_tau[j]; });

    return {std::make_shared<MeshVariant>(new_mesh),
            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction{new_mesh, new_tau})};
  }

  // Variant front-end: checks that tau_v lives on a mesh and is P0, then forwards.
  std::tuple<std::shared_ptr<const MeshVariant>, std::shared_ptr<const DiscreteFunctionVariant>>
  apply(const double& dt, const std::shared_ptr<const DiscreteFunctionVariant>& tau_v) const
  {
    std::shared_ptr mesh_v = getCommonMesh({tau_v});
    if (not mesh_v) {
      throw NormalError("discrete function is not defined on the same mesh");
    }

    if (not checkDiscretizationType({tau_v}, DiscreteFunctionType::P0)) {
      throw NormalError("traffic solver expects P0 functions");
    }
    return this->apply(dt, mesh_v->get<MeshType>(), tau_v->get<DiscreteScalarFunction>());
  }

  ImplicitIterativeTrafficSolver(
    const std::shared_ptr<const MeshVariant>& mesh_v,
    const std::shared_ptr<const DiscreteFunctionVariant>& tau_v,
    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
    const double& dt)
    : ImplicitIterativeTrafficSolver(mesh_v->get<MeshType>(),
                                     tau_v->get<DiscreteScalarFunction>(),
                                     bc_descriptor_list,
                                     dt)
  {}
  // Defined as deleted: the reference member m_mesh has no default initializer.
  ImplicitIterativeTrafficSolver()                                 = default;
  ImplicitIterativeTrafficSolver(ImplicitIterativeTrafficSolver&&) = default;
  ~ImplicitIterativeTrafficSolver()                                = default;
};

template <MeshConcept MeshType>
class ImplicitIterativeTrafficSolverHandler::ImplicitIterativeTrafficSolver<MeshType>::VelocityBoundaryCondition
{
 private:
  constexpr static size_t Dimension = MeshType::Dimension;

  // Prescribed velocity, one entry per node of m_node_list (same indexing).
  const Array<const TinyVector<Dimension>> m_value_list;
  // Boundary nodes on which the velocity is imposed.
  const Array<const NodeId> m_node_list;

 public:
  // Nodes carrying the prescribed velocity.
  const Array<const NodeId>&
  nodeList() const
  {
    const Array<const NodeId>& nodes = m_node_list;
    return nodes;
  }

  // Prescribed velocities, aligned with nodeList().
  const Array<const TinyVector<Dimension>>&
  valueList() const
  {
    const Array<const TinyVector<Dimension>>& values = m_value_list;
    return values;
  }

  VelocityBoundaryCondition(const Array<const NodeId>& nodes, const Array<const TinyVector<Dimension>>& values)
    : m_value_list(values), m_node_list(nodes)
  {}

  ~VelocityBoundaryCondition() = default;
};

template <MeshConcept MeshType>
class ImplicitIterativeTrafficSolverHandler::ImplicitIterativeTrafficSolver<MeshType>::FreeBoundaryCondition
{
 private:
  // Nodes of the free boundary; this condition stores no prescribed value.
  const Array<const NodeId> m_node_list;

 public:
  const Array<const NodeId>&
  nodeList() const
  {
    return m_node_list;
  }

  // explicit: a bare Array of node ids must not silently convert into a free boundary
  // condition (nor, through it, into the BoundaryCondition variant).
  explicit FreeBoundaryCondition(const Array<const NodeId>& node_list) : m_node_list{node_list} {}

  ~FreeBoundaryCondition() = default;
};

ImplicitIterativeTrafficSolverHandler::ImplicitIterativeTrafficSolverHandler(
  const std::shared_ptr<const DiscreteFunctionVariant>& tau_v,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const double& dt)
{
  // tau must be attached to a mesh and be a P0 (cell-wise) discrete function.
  auto p_mesh_v = getCommonMesh({tau_v});
  if (p_mesh_v == nullptr) {
    throw NormalError("discrete function is not defined on the same mesh");
  }
  if (not checkDiscretizationType({tau_v}, DiscreteFunctionType::P0)) {
    throw NormalError("traffic solver expects P0 functions");
  }

  // Build the solver for the concrete mesh type; only Mesh<1> is supported.
  auto build_solver = [&](auto&& mesh) {
    using MeshType = mesh_type_t<decltype(mesh)>;
    if constexpr (not std::is_same_v<MeshType, Mesh<1>>) {
      throw NormalError("unexpected mesh type");
    } else {
      m_implicit_iterative_traffic_solver =
        std::make_unique<ImplicitIterativeTrafficSolver<MeshType>>(p_mesh_v, tau_v, bc_descriptor_list, dt);
    }
  };
  std::visit(build_solver, p_mesh_v->variant());
}
