#include <scheme/UpwindImplicitTrafficSolver.hpp>

#include <algebra/CRSMatrix.hpp>
#include <algebra/CRSMatrixDescriptor.hpp>
#include <algebra/LinearSolver.hpp>
#include <algebra/Vector.hpp>
#include <language/utils/FunctionSymbolId.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/DiscreteFunctionVariant.hpp>
#include <scheme/FreeBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>

#include <variant>
#include <vector>

/**
 * Implicit upwind solver for a 1D traffic model written in terms of the specific volume tau.
 *
 * The whole implicit step is performed by the (private) constructor: the nonlinear residual
 * F(tau) = 0 of the scheme (see _getF) is solved with Newton's method, starting from tau^n,
 * and the result is kept in m_tau. apply() then displaces the mesh nodes with the velocity
 * u = 1 - 1/tau and returns the displaced mesh together with the new specific volume.
 *
 * NOTE(review): connectivity is indexed as if the 1D mesh were numbered from left to right
 * (second node of a cell = its right node, second cell of a node = the right cell); this
 * ordering is assumed, not checked.
 */
template <MeshConcept MeshType>
class UpwindImplicitTrafficSolverHandler::UpwindImplicitTrafficSolver final
  : public UpwindImplicitTrafficSolverHandler::IUpwindImplicitTrafficSolver
{
  constexpr static size_t Dimension = MeshType::Dimension;

  using Rdxd = TinyMatrix<Dimension>;
  using Rd   = TinyVector<Dimension>;

  using MeshDataType = MeshData<MeshType>;

  using DiscreteScalarFunction = DiscreteFunctionP0<const double>;
  using DiscreteVectorFunction = DiscreteFunctionP0<const Rd>;

  class FreeBoundaryCondition;
  class VelocityBoundaryCondition;

  using BoundaryCondition = std::variant<FreeBoundaryCondition, VelocityBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;

  BoundaryConditionList m_boundary_condition_list;
  const MeshType& m_mesh;

  // NOTE(review): m_ur and m_upwind_flux are neither assigned nor read in this file
  NodeValue<const Rd> m_ur;
  NodeValue<const Rd> m_upwind_flux;
  // Inverse cell mass 1/M_j = tau_j / V_j, computed once from the initial specific volume
  CellValue<const double> m_inv_Mj;

  // Specific volume at the end of the step, produced by the Newton iterations of the constructor
  CellValue<double> m_tau;

  // Density rho_j = 1 / tau_j (NOTE(review): not called in this file)
  CellValue<const double>
  _getRho(const DiscreteScalarFunction& tau)
  {
    CellValue<double> rho{m_mesh.connectivity()};
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { rho[j] = 1. / tau[j]; });
    return rho;
  }

  // Velocity u_j = 1 - 1 / tau_j (NOTE(review): not called in this file)
  CellValue<const double>
  _getU(const DiscreteScalarFunction& tau)
  {
    CellValue<double> u{m_mesh.connectivity()};
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { u[j] = 1 - 1. / tau[j]; });
    return u;
  }

  // Cell mass M_j = V_j / tau_j (NOTE(review): not called in this file)
  const CellValue<const double>
  _getMj(const DiscreteScalarFunction& tau)
  {
    MeshDataType& mesh_data           = MeshDataManager::instance().getMeshData(m_mesh);
    const CellValue<const double>& Vj = mesh_data.Vj();
    CellValue<double> Mj{m_mesh.connectivity()};
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { Mj[j] = Vj[j] * 1. / tau[j]; });
    return Mj;
  }

  /**
   * Residual F of the implicit scheme evaluated at the Newton iterate tauk.
   *
   * For a cell j whose right node is shared with the cell j1:
   *   F_j = tauk_j - taun_j + dt/M_j * (1/tauk_{j1} - 1/tauk_j)
   *       = tauk_j - taun_j - dt/M_j * (u(tauk_{j1}) - u(tauk_j)),   u(tau) = 1 - 1/tau
   * For the cell attached to a node carrying a velocity boundary condition v:
   *   F_j = tauk_j - taun_j + dt/M_j * (1 - v) - dt/M_j / tauk_j
   *
   * NOTE(review): the boundary formula replaces the right-node flux of the cell, so it is only
   * meaningful for a velocity condition on the right end of the domain. The last cell is not
   * covered by the interior loop; its residual is only written by such a boundary condition.
   *
   * @param taun specific volume at time t^n
   * @param tauk current Newton iterate
   * @param dt time step
   * @return the residual vector, indexed by cell number
   */
  Vector<double>
  _getF(const Vector<const double>& taun, const Vector<const double>& tauk, const double dt)
  {
    Vector<double> function_f{m_mesh.numberOfCells()};

    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();
    const auto& node_to_cell_matrix = m_mesh.connectivity().nodeToCellMatrix();

    CellValue<double> computed_f(m_mesh.connectivity());
    for (CellId j = 0; j < m_mesh.numberOfCells() - 1; ++j) {
      const auto& cell_node = cell_to_node_matrix[j];
      NodeId r              = cell_node[1];
      const auto& node_cell = node_to_cell_matrix[r];
      CellId j1             = node_cell[1];
      computed_f[j]         = tauk[j] - taun[j] + (dt * m_inv_Mj[j]) / tauk[j1] - (dt * m_inv_Mj[j]) / tauk[j];
    }

    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const auto& node_list  = bc.nodeList();
            const auto& value_list = bc.valueList();
            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId& node_id = node_list[i_node];
              const auto& node_cell = node_to_cell_matrix[node_id];
              CellId j              = node_cell[0];
              computed_f[j] =
                tauk[j] - taun[j] + (dt * m_inv_Mj[j]) * (1 - value_list[i_node][0]) - (dt * m_inv_Mj[j]) / tauk[j];
            }
          }
        },
        boundary_condition);
    }

    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) {
        size_t i      = j;
        function_f[i] = computed_f[j];
      });

    return function_f;
  }

  /**
   * Jacobian dF/dtau of the residual computed by _getF, evaluated at the Newton iterate tauk.
   *
   * Every row j has the diagonal entry 1 + dt/M_j / tauk_j^2 (the same for interior and
   * velocity-boundary rows). Interior rows additionally carry -dt/M_j / tauk_{j1}^2 in the
   * column of their right neighbour j1. Rows overwritten by a velocity boundary condition in
   * _getF do not depend on tauk_{j1}, so they only carry the diagonal entry.
   *
   * @param tauk current Newton iterate
   * @param dt time step
   * @return the Jacobian in CRS format
   */
  CRSMatrix<double>
  _getGradientF(const Vector<const double>& tauk, const double dt)
  {
    Array<int> non_zeros{m_mesh.numberOfCells()};
    non_zeros.fill(2);
    CRSMatrixDescriptor grad_f{m_mesh.numberOfCells(), non_zeros};

    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();
    const auto& node_to_cell_matrix = m_mesh.connectivity().nodeToCellMatrix();

    // Cells whose residual is given by a velocity boundary condition in _getF
    CellValue<bool> is_velocity_bc_cell{m_mesh.connectivity()};
    is_velocity_bc_cell.fill(false);

    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const auto& node_list = bc.nodeList();

            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId& node_id = node_list[i_node];
              const auto& node_cell = node_to_cell_matrix[node_id];
              CellId j              = node_cell[0];

              is_velocity_bc_cell[j] = true;
            }
          }
        },
        boundary_condition);
    }

    for (CellId j = 0; j < m_mesh.numberOfCells(); ++j) {
      size_t number_of_the_cell = j;
      int row                   = number_of_the_cell;

      // derivative of (tau_j - dt/M_j / tau_j) with respect to tau_j
      grad_f(row, row) = 1 + (dt * m_inv_Mj[j]) / (tauk[j] * tauk[j]);
    }

    for (CellId j = 0; j < m_mesh.numberOfCells() - 1; ++j) {
      if (is_velocity_bc_cell[j]) {
        // the boundary residual of this cell does not involve its right neighbour
        continue;
      }

      const auto& cell_node = cell_to_node_matrix[j];
      NodeId r              = cell_node[1];
      const auto& node_cell = node_to_cell_matrix[r];
      CellId j1             = node_cell[1];

      size_t number_of_the_cell     = j;
      size_t number_of_the_neighbor = j1;
      int row                       = number_of_the_cell;
      int column                    = number_of_the_neighbor;

      // same neighbour j1 as in _getF, instead of assuming it is numbered j+1
      grad_f(row, column) = -dt * m_inv_Mj[j] / (tauk[j1] * tauk[j1]);
    }

    return grad_f.getCRSMatrix();
  }

  /**
   * Translates the boundary condition descriptors into solver boundary conditions.
   * Accepted: Dirichlet conditions named "velocity" (values interpolated at the boundary
   * nodes) and free conditions. Anything else throws NormalError.
   */
  BoundaryConditionList
  _getBCList(const std::shared_ptr<const MeshType>& mesh,
             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
  {
    BoundaryConditionList bc_list;

    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if (dirichlet_bc_descriptor.name() == "velocity") {
          MeshNodeBoundary mesh_node_boundary =
            getMeshNodeBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

          Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
            TinyVector<Dimension>)>::template interpolate<ItemType::node>(dirichlet_bc_descriptor.rhsSymbolId(),
                                                                          mesh->xr(), mesh_node_boundary.nodeList());

          bc_list.push_back(VelocityBoundaryCondition{mesh_node_boundary.nodeList(), value_list});
        } else {
          is_valid_boundary_condition = false;
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::free: {   // <- Free boundary
        const FreeBoundaryConditionDescriptor& free_bc_descriptor =
          dynamic_cast<const FreeBoundaryConditionDescriptor&>(*bc_descriptor);

        MeshNodeBoundary mesh_node_boundary = getMeshNodeBoundary(*mesh, free_bc_descriptor.boundaryDescriptor());

        bc_list.push_back(FreeBoundaryCondition{mesh_node_boundary.nodeList()});
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for traffic solver";
        throw NormalError(error_msg.str());
      }
    }

    return bc_list;
  }

  /**
   * Performs the implicit step by solving F(tau) = 0 with Newton's method, starting from tau^n.
   *
   * @param p_mesh mesh at time t^n
   * @param tau specific volume at time t^n (also used to compute the inverse cell masses)
   * @param bc_descriptor_list boundary conditions ("velocity" Dirichlet or free)
   * @param dt time step
   *
   * Iterations stop when ||delta||_inf <= 1e-14 * ||tau^n||_inf or after 10000 iterations.
   * NOTE(review): reaching the iteration limit is not reported; the last iterate is kept.
   */
  UpwindImplicitTrafficSolver(
    const std::shared_ptr<const MeshType>& p_mesh,
    const DiscreteScalarFunction& tau,
    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
    const double& dt)
    : m_boundary_condition_list{this->_getBCList(p_mesh, bc_descriptor_list)},
      m_mesh{*p_mesh},   // m_tau{0} as the initial solution vector?
      m_inv_Mj{[&] {
        MeshDataType& mesh_data           = MeshDataManager::instance().getMeshData(m_mesh);
        const CellValue<const double>& Vj = mesh_data.Vj();
        CellValue<double> inv_Mj{m_mesh.connectivity()};
        parallel_for(
          m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { inv_Mj[j] = tau[j] * 1. / Vj[j]; });
        return inv_Mj;
      }()}
  {
    // tau^n stored as an algebra vector indexed by cell number
    Vector<const double> taun = [&] {
      Vector<double> compute_taun{tau.cellValues().numberOfItems()};
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { compute_taun[j] = tau[j]; });

      return compute_taun;
    }();

    // Newton iterate, initialized with tau^n
    Vector<double> tauk = copy(taun);

    int nb_iter = 0;
    double norm_inf_sol;

    CellValue<double> abs_taun = [&]() {
      CellValue<double> computed_abs_taun(m_mesh.connectivity());
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { computed_abs_taun[j] = std::abs(taun[j]); });
      return computed_abs_taun;
    }();

    // Scale of the relative stopping criterion
    double norm_inf_taun = max(abs_taun);

    do {
      std::cout << "iteration=" << nb_iter << '\n';
      nb_iter++;

      Vector<double> f             = this->_getF(taun, tauk, dt);
      CRSMatrix<double> gradient_f = this->_getGradientF(tauk, dt);

      // Newton correction sol: gradient_f * sol = f (sol is filled with 1 before the solve,
      // presumably used as starting value by iterative solvers)
      Vector<double> sol{m_mesh.numberOfCells()};
      sol.fill(1);

      LinearSolver solver;
      solver.solveLocalSystem(gradient_f, sol, f);

      Vector<double> tau_next = tauk - sol;

      Array<const double> abs_sol = [&]() {
        Array<double> compute_abs_sol{sol.size()};
        parallel_for(
          sol.size(), PUGS_LAMBDA(size_t i) { compute_abs_sol[i] = std::abs(sol[i]); });
        return compute_abs_sol;
      }();

      norm_inf_sol = max(abs_sol);

      tauk = tau_next;
      std::cout << "ratio" << norm_inf_sol / norm_inf_taun << "\n";

    } while ((norm_inf_sol > 1e-14 * norm_inf_taun) and (nb_iter < 10000));

    m_tau = [&] {
      CellValue<double> computed_tau(m_mesh.connectivity());
      parallel_for(
        m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { computed_tau[j] = tauk[j]; });
      return computed_tau;
    }();
  }

 public:
  /**
   * Builds the displaced mesh and the specific volume at the new time.
   *
   * Node 0 moves with u = 1 - 1/tau of its (first) cell, interior nodes with u of the second
   * cell attached to them, tau being the value m_tau computed by the constructor. Nodes with
   * a velocity boundary condition move with the prescribed velocity.
   *
   * NOTE(review): the last node is only assigned through a velocity boundary condition, and
   * m_tau was computed with the dt given to the constructor; both are assumed consistent with
   * the arguments given here — TODO confirm.
   *
   * @return the new mesh and a copy of the new specific volume defined on it
   */
  std::tuple<std::shared_ptr<const MeshVariant>, std::shared_ptr<DiscreteFunctionVariant>>
  apply(const double& dt, const std::shared_ptr<const MeshType>& mesh, const DiscreteScalarFunction& tau) const
  {
    NodeValue<Rd> new_position = [&]() {
      const auto& node_to_cell_matrix = mesh->connectivity().nodeToCellMatrix();

      NodeValue<Rd> x_next(mesh->connectivity());

      {
        NodeId r               = 0;
        const auto& node_cells = node_to_cell_matrix[r];
        CellId j               = node_cells[0];
        x_next[r]              = mesh->xr()[r] + dt * Rd{1 - 1. / m_tau[j]};
      }
      for (NodeId r = 1; r < mesh->connectivity().numberOfNodes() - 1; ++r) {
        const auto& node_cells = node_to_cell_matrix[r];
        CellId j               = node_cells[1];
        x_next[r]              = mesh->xr()[r] + dt * Rd{1 - 1. / m_tau[j]};
      }

      for (const auto& boundary_condition : m_boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
              const auto& node_list  = bc.nodeList();
              const auto& value_list = bc.valueList();
              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id = node_list[i_node];

                x_next[node_id] = mesh->xr()[node_id] + dt * value_list[i_node];
              }
            }
          },
          boundary_condition);
      }
      return x_next;
    }();

    std::shared_ptr<const MeshType> new_mesh = std::make_shared<MeshType>(mesh->shared_connectivity(), new_position);

    // Returned values are copied so that the new function does not share storage with m_tau
    CellValue<double> new_tau = copy(tau.cellValues());
    parallel_for(
      new_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { new_tau[j] = m_tau[j]; });

    return {std::make_shared<MeshVariant>(new_mesh),
            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction{new_mesh, new_tau})};
  }

  /**
   * Variant-based entry point: checks that tau is a P0 function on a single mesh, then
   * forwards to the typed apply().
   */
  std::tuple<std::shared_ptr<const MeshVariant>, std::shared_ptr<const DiscreteFunctionVariant>>
  apply(const double& dt, const std::shared_ptr<const DiscreteFunctionVariant>& tau) const
  {
    std::shared_ptr mesh_v = getCommonMesh({tau});
    if (not mesh_v) {
      throw NormalError("discrete function is not defined on the same mesh");
    }

    if (not checkDiscretizationType({tau}, DiscreteFunctionType::P0)) {
      throw NormalError("traffic solver expects P0 functions");
    }
    return this->apply(dt, mesh_v->get<MeshType>(), tau->get<DiscreteScalarFunction>());
  }

  // Unwraps the variants and delegates to the typed constructor
  UpwindImplicitTrafficSolver(
    const std::shared_ptr<const MeshVariant>& mesh_v,
    const std::shared_ptr<const DiscreteFunctionVariant>& tau_v,
    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
    const double& dt)
    : UpwindImplicitTrafficSolver(mesh_v->get<MeshType>(), tau_v->get<DiscreteScalarFunction>(), bc_descriptor_list, dt)
  {}
  UpwindImplicitTrafficSolver()                              = default;
  UpwindImplicitTrafficSolver(UpwindImplicitTrafficSolver&&) = default;
  ~UpwindImplicitTrafficSolver()                             = default;
};

/**
 * Velocity imposed on a set of boundary nodes.
 * The i-th entry of valueList() is the velocity prescribed at the i-th node of nodeList().
 */
template <MeshConcept MeshType>
class UpwindImplicitTrafficSolverHandler::UpwindImplicitTrafficSolver<MeshType>::VelocityBoundaryCondition
{
 private:
  // Prescribed velocities, in the same order as m_node_list
  const Array<const TinyVector<Dimension>> m_value_list;
  // Boundary nodes on which the velocity is imposed
  const Array<const NodeId> m_node_list;

 public:
  /// Boundary nodes carrying the condition
  const Array<const NodeId>&
  nodeList() const
  {
    const Array<const NodeId>& boundary_nodes = m_node_list;
    return boundary_nodes;
  }

  /// Velocity value attached to each boundary node
  const Array<const TinyVector<Dimension>>&
  valueList() const
  {
    const Array<const TinyVector<Dimension>>& imposed_velocities = m_value_list;
    return imposed_velocities;
  }

  VelocityBoundaryCondition(const Array<const NodeId>& node_list, const Array<const TinyVector<Dimension>>& value_list)
    : m_value_list(value_list), m_node_list(node_list)
  {}

  ~VelocityBoundaryCondition() = default;
};

/**
 * Free boundary: only the list of boundary nodes is stored, no value is attached to them.
 */
template <MeshConcept MeshType>
class UpwindImplicitTrafficSolverHandler::UpwindImplicitTrafficSolver<MeshType>::FreeBoundaryCondition
{
 private:
  // Boundary nodes covered by this condition
  const Array<const NodeId> m_node_list;

 public:
  /// Boundary nodes covered by this condition
  const Array<const NodeId>&
  nodeList() const
  {
    const Array<const NodeId>& free_nodes = m_node_list;
    return free_nodes;
  }

  FreeBoundaryCondition(const Array<const NodeId>& node_list) : m_node_list(node_list) {}

  ~FreeBoundaryCondition() = default;
};

/**
 * Instantiates the solver for the mesh on which tau_v is defined.
 *
 * Throws NormalError when tau_v has no common mesh, is not a P0 function, or lives on a mesh
 * whose type is not Mesh<1>.
 */
UpwindImplicitTrafficSolverHandler::UpwindImplicitTrafficSolverHandler(
  const std::shared_ptr<const DiscreteFunctionVariant>& tau_v,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const double& dt)
{
  const std::shared_ptr mesh_v = getCommonMesh({tau_v});
  if (not mesh_v) {
    throw NormalError("discrete function is not defined on the same mesh");
  }

  const bool is_p0 = checkDiscretizationType({tau_v}, DiscreteFunctionType::P0);
  if (not is_p0) {
    throw NormalError("traffic solver expects P0 functions");
  }

  // Only 1D meshes are supported: every other alternative of the variant is rejected
  auto build_solver = [&](auto&& mesh) {
    using MeshType = mesh_type_t<decltype(mesh)>;
    if constexpr (not std::is_same_v<MeshType, Mesh<1>>) {
      throw NormalError("unexpected mesh type");
    } else {
      m_upwind_implicit_traffic_solver =
        std::make_unique<UpwindImplicitTrafficSolver<MeshType>>(mesh_v, tau_v, bc_descriptor_list, dt);
    }
  };

  std::visit(build_solver, mesh_v->variant());
}
