#include <language/algorithms/IterativerhoTrafficAlgorithm.hpp>

#include <algebra/BiCGStab.hpp>
#include <algebra/CRSMatrix.hpp>
#include <algebra/CRSMatrixDescriptor.hpp>
#include <algebra/TinyMatrix.hpp>
#include <algebra/TinyVector.hpp>
#include <algebra/Vector.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <mesh/MeshVariant.hpp>
#include <output/VTKWriter.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/FreeBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>

/**
 * Implicit Lagrangian solver for a traffic-flow model, using fixed-point iterations on
 * the specific volume tau.
 *
 * Throughout this solver rho = 1 / tau and u = 1 - rho. At each time step the new tau
 * is obtained by repeated sweeps over the cells (at most max_fixed_point_iterations)
 * until the largest change between two sweeps is no longer above tau_tolerance. The
 * nodes are then moved, and a new mesh sharing the same connectivity is built.
 *
 * NOTE(review): each sweep starts at the cell attached to a "velocity" boundary node and
 * walks through cell_to_node[j][0] / node_to_cell[r][0]. This presumably assumes cells
 * and nodes are numbered left to right, with the velocity boundary at the right end.
 * TODO confirm for other numberings.
 *
 * @param mesh_v             mesh to start from (accessed as MeshType)
 * @param bc_descriptor_list "velocity" Dirichlet and free boundary conditions
 * @param tau_id             function giving the initial tau at cell centers
 * @param tmax               final time
 * @param given_dt           time step (the last one is shortened to stop exactly at tmax)
 * @throws NormalError if a boundary condition is neither a "velocity" Dirichlet nor free
 */
template <MeshConcept MeshType>
IterativerhoTrafficAlgorithm<MeshType>::IterativerhoTrafficAlgorithm(
  std::shared_ptr<const MeshVariant> mesh_v,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const FunctionSymbolId& tau_id,
  const double& tmax,
  const double& given_dt)
{
  constexpr static size_t Dimension = MeshType::Dimension;

  using MeshDataType = MeshData<MeshType>;
  using Rd           = TinyVector<Dimension>;

  // Fixed-point stopping criteria. Note: 10e-8 == 1e-7 (value kept from the original tuning).
  constexpr double tau_tolerance           = 10e-8;
  constexpr int max_fixed_point_iterations = 200;

  std::shared_ptr<const MeshType> mesh = mesh_v->get<MeshType>();

  BoundaryConditionList bc_list;

  for (const auto& bc_descriptor : bc_descriptor_list) {
    bool is_valid_boundary_condition = true;

    switch (bc_descriptor->type()) {
    case IBoundaryConditionDescriptor::Type::dirichlet: {
      const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
        dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
      if (dirichlet_bc_descriptor.name() == "velocity") {
        MeshNodeBoundary mesh_node_boundary = getMeshNodeBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

        Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
          TinyVector<Dimension>)>::template interpolate<ItemType::node>(dirichlet_bc_descriptor.rhsSymbolId(),
                                                                        mesh->xr(), mesh_node_boundary.nodeList());

        bc_list.push_back(VelocityBoundaryCondition{mesh_node_boundary.nodeList(), value_list});
      } else {
        is_valid_boundary_condition = false;
      }
      break;
    }
    case IBoundaryConditionDescriptor::Type::free: {
      const FreeBoundaryConditionDescriptor& free_bc_descriptor =
        dynamic_cast<const FreeBoundaryConditionDescriptor&>(*bc_descriptor);

      MeshNodeBoundary mesh_node_boundary = getMeshNodeBoundary(*mesh, free_bc_descriptor.boundaryDescriptor());

      bc_list.push_back(FreeBoundaryCondition{mesh_node_boundary.nodeList()});
      break;
    }
    default: {
      is_valid_boundary_condition = false;
    }
    }
    if (not is_valid_boundary_condition) {
      std::ostringstream error_msg;
      error_msg << *bc_descriptor << " is an invalid boundary condition for implicit traffic flow solver";
      throw NormalError(error_msg.str());
    }
  }

  std::cout << "number of bc descr = " << bc_descriptor_list.size() << '\n';

  // initial specific volume, interpolated at cell centers
  CellValue<double> tauj(mesh->connectivity());
  {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

    tauj =
      InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(tau_id, mesh_data.xj());
  }

  CellValue<double> rhoj = [&]() {
    CellValue<double> computed_rhoj(mesh->connectivity());
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { computed_rhoj[j] = 1. / tauj[j]; });
    return computed_rhoj;
  }();

  CellValue<double> uj = [&]() {
    CellValue<double> computed_uj(mesh->connectivity());
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { computed_uj[j] = 1 - rhoj[j]; });
    return computed_uj;
  }();

  // cell masses M_j = V_j / tau_j, computed once on the initial mesh and reused at every step
  const CellValue<const double> Mj = [&]() {
    MeshDataType& mesh_data           = MeshDataManager::instance().getMeshData(*mesh);
    const CellValue<const double>& Vj = mesh_data.Vj();

    CellValue<double> computed_Mj(mesh->connectivity());
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { computed_Mj[j] = 1. / tauj[j] * Vj[j]; });

    return computed_Mj;
  }();

  const CellValue<const double> inv_Mj = [&]() {
    CellValue<double> computed_inv_Mj(mesh->connectivity());
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { computed_inv_Mj[j] = 1. / Mj[j]; });

    return computed_inv_Mj;
  }();

  double t    = 0;
  double dt   = given_dt;
  int itermax = std::numeric_limits<int>::max();

  int iteration = 0;

  while ((t < tmax) and (iteration < itermax)) {
    // shorten the last step so that t lands exactly on tmax
    if (t + dt > tmax) {
      dt = tmax - t;
    }
    t += dt;
    ++iteration;

    const auto& node_to_cell_matrix = mesh->connectivity().nodeToCellMatrix();
    const auto& cell_to_node_matrix = mesh->connectivity().cellToNodeMatrix();

    // tau^nu (current fixed-point iterate) and tau^{nu+1}. Both start from tau^n so that
    // every cell holds a defined value, including cells not reached by the sweep.
    CellValue<double> tau_iter = copy(tauj);
    CellValue<double> tau_next = copy(tauj);
    CellValue<double> diff_tau(mesh->connectivity());

    // Implicit update of tau in `cell`, closed by the density `rho_neighbour`
    // (boundary density 1 - u_bc, or 1 / tau_next of the previously updated cell).
    const auto implicit_tau = [&](const CellId cell, const double rho_neighbour) {
      const double coef = dt * inv_Mj[cell] * (1. / (tauj[cell] * tau_iter[cell]));
      return (1 + coef) / (1. / tauj[cell] + coef * rho_neighbour);
    };

    int nb_iter         = 0;
    double max_diff_tau = 0;

    // A do-while guarantees at least one sweep; convergence is only tested on computed values.
    do {
      std::cout << "iteration=" << nb_iter << '\n';
      nb_iter++;

      // going from right to left cell
      for (const auto& boundary_condition : bc_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
              const auto& node_list  = bc.nodeList();
              const auto& value_list = bc.valueList();
              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id = node_list[i_node];
                const auto& node_cell = node_to_cell_matrix[node_id];
                CellId j              = node_cell[0];

                // boundary cell: the neighbouring density comes from the prescribed velocity (rho = 1 - u)
                tau_next[j] = implicit_tau(j, 1 - value_list[i_node][0]);

                for (size_t step = 0; step + 1 < mesh->numberOfCells(); ++step) {
                  const auto& cell_node   = cell_to_node_matrix[j];
                  const NodeId r          = cell_node[0];
                  const auto& node_cell_2 = node_to_cell_matrix[r];
                  const CellId j_1        = node_cell_2[0];

                  tau_next[j_1] = implicit_tau(j_1, 1. / tau_next[j]);
                  j             = j_1;
                }
              }
            }
          },
          boundary_condition);
      }

      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { diff_tau[j] = std::abs(tau_next[j] - tau_iter[j]); });
      max_diff_tau = max(diff_tau);
      std::cout << "diff=" << max_diff_tau << '\n';

      // we take tau_j^nu = tau_j^{nu+1}
      parallel_for(
        mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { tau_iter[j] = tau_next[j]; });
    } while ((max_diff_tau > tau_tolerance) and (nb_iter < max_fixed_point_iterations));

    // New node positions. Start from the current positions so that any node not handled
    // below (e.g. the last node when no velocity condition covers it) stays well-defined.
    NodeValue<Rd> new_xr = [&]() {
      NodeValue<Rd> x_next = copy(mesh->xr());

      {
        const NodeId r         = 0;
        const auto& node_cells = node_to_cell_matrix[r];
        const CellId j         = node_cells[0];
        x_next[r]              = mesh->xr()[r] + dt * Rd{1 - 1. / tau_next[j]};
      }

      for (NodeId r = 1; r < mesh->connectivity().numberOfNodes() - 1; ++r) {
        const auto& node_cells = node_to_cell_matrix[r];
        const CellId j         = node_cells[1];
        x_next[r]              = mesh->xr()[r] + dt * Rd{1 - 1. / tau_next[j]};
      }

      // nodes carrying a velocity condition move with the prescribed velocity
      for (const auto& boundary_condition : bc_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
              const auto& node_list  = bc.nodeList();
              const auto& value_list = bc.valueList();
              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id = node_list[i_node];

                x_next[node_id] = mesh->xr()[node_id] + dt * value_list[i_node];
              }
            }
          },
          boundary_condition);
      }
      return x_next;
    }();

    mesh = std::make_shared<MeshType>(mesh->shared_connectivity(), new_xr);

    std::cout.setf(std::cout.scientific);
    std::cout << "iteration " << rang::fg::cyan << std::setw(4) << iteration << rang::style::reset
              << " time=" << rang::fg::green << t << rang::style::reset << " dt=" << rang::fgB::blue << dt
              << rang::style::reset << '\n';

    for (CellId j = 0; j < mesh->connectivity().numberOfCells(); ++j) {
      tauj[j] = tau_next[j];
      rhoj[j] = 1. / tauj[j];
      uj[j]   = 1 - 1. / tauj[j];
    }
  }
}

/**
 * "velocity" Dirichlet boundary condition: the boundary nodes together with the velocity
 * prescribed at each of them (valueList()[i] is the value attached to nodeList()[i]).
 */
template <MeshConcept MeshType>
class IterativerhoTrafficAlgorithm<MeshType>::VelocityBoundaryCondition
{
 private:
  constexpr static size_t Dimension = MeshType::Dimension;

  using NodeArray  = Array<const NodeId>;
  using ValueArray = Array<const TinyVector<Dimension>>;

  const ValueArray m_value_list;
  const NodeArray m_node_list;

 public:
  /// @param node_list  boundary nodes
  /// @param value_list velocity prescribed at each node of node_list
  VelocityBoundaryCondition(const NodeArray& node_list, const ValueArray& value_list)
    : m_value_list{value_list}, m_node_list{node_list}
  {}

  /// @return the boundary nodes
  const NodeArray&
  nodeList() const
  {
    return m_node_list;
  }

  /// @return the prescribed velocities, in nodeList() order
  const ValueArray&
  valueList() const
  {
    return m_value_list;
  }

  ~VelocityBoundaryCondition() = default;
};

/**
 * Free boundary condition: only records the boundary nodes. In this file, the solver
 * loops act on VelocityBoundaryCondition only; free boundaries are accepted and stored,
 * but nothing is prescribed on them.
 */
template <MeshConcept MeshType>
class IterativerhoTrafficAlgorithm<MeshType>::FreeBoundaryCondition
{
 private:
  const Array<const NodeId> m_node_list;

 public:
  /// @return the nodes of the free boundary
  const Array<const NodeId>&
  nodeList() const
  {
    return m_node_list;
  }

  // explicit: a bare node list must not silently convert into a boundary condition
  explicit FreeBoundaryCondition(const Array<const NodeId>& node_list) : m_node_list{node_list} {}

  // Rule of Zero: the compiler-generated special members are correct for this class
};

// Explicit instantiation: in this file the solver is instantiated for the 1D mesh only
// (the node update builds Rd from a single scalar).
template IterativerhoTrafficAlgorithm<Mesh<1>>::IterativerhoTrafficAlgorithm(
  std::shared_ptr<const MeshVariant>,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
  const FunctionSymbolId&,
  const double&,
  const double&);
