#include <language/algorithms/UpwindExplicitTrafficAlgorithm.hpp>

#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <output/VTKWriter.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/FreeBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>

// Runs the explicit upwind traffic-flow solver on the mesh held by `mesh_v`.
//
// Steps performed by this constructor:
//  1. translate each boundary condition descriptor into an entry of the
//     boundary condition list,
//  2. interpolate the initial cell values tau_j from the user function
//     `tau_id` evaluated at the cell centers,
//  3. advance tau_j and the node positions in time until `tmax` is reached.
//
// \param mesh_v             mesh variant from which the MeshType mesh is taken
// \param bc_descriptor_list boundary condition descriptors; only Dirichlet
//                           conditions named "velocity" and free boundary
//                           conditions are accepted
// \param tau_id             symbol of the user function giving the initial tau
// \param tmax               final time of the simulation
//
// \throws NormalError if a descriptor is neither a Dirichlet condition named
//         "velocity" nor a free boundary condition.
//
// NOTE(review): the moved mesh and the cell values computed here are locals;
// nothing visible in this block stores them for later use -- confirm whether
// they should be kept as members.
template <MeshConcept MeshType>
UpwindExplicitTrafficAlgorithm<MeshType>::UpwindExplicitTrafficAlgorithm(
  std::shared_ptr<const MeshVariant> mesh_v,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const FunctionSymbolId& tau_id,
  const double& tmax)
{
  constexpr static size_t Dimension = MeshType::Dimension;

  using MeshDataType = MeshData<MeshType>;

  std::shared_ptr<const MeshType> mesh = mesh_v->get<MeshType>();

  using Rd = TinyVector<Dimension>;

  BoundaryConditionList bc_list;

  // Translate the descriptors into boundary conditions, rejecting any kind
  // that this solver does not handle.
  for (const auto& bc_descriptor : bc_descriptor_list) {
    bool is_valid_boundary_condition = true;

    switch (bc_descriptor->type()) {
    case IBoundaryConditionDescriptor::Type::dirichlet: {
      const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
        dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
      if (dirichlet_bc_descriptor.name() == "velocity") {
        MeshNodeBoundary mesh_node_boundary = getMeshNodeBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

        // Evaluate the imposed value at the boundary nodes.
        Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
          TinyVector<Dimension>)>::template interpolate<ItemType::node>(dirichlet_bc_descriptor.rhsSymbolId(),
                                                                        mesh->xr(), mesh_node_boundary.nodeList());

        bc_list.push_back(VelocityBoundaryCondition{mesh_node_boundary.nodeList(), value_list});
      } else {
        is_valid_boundary_condition = false;
      }
      break;
    }
    case IBoundaryConditionDescriptor::Type::free: {   // <- Free boundary
      const FreeBoundaryConditionDescriptor& free_bc_descriptor =
        dynamic_cast<const FreeBoundaryConditionDescriptor&>(*bc_descriptor);

      MeshNodeBoundary mesh_node_boundary = getMeshNodeBoundary(*mesh, free_bc_descriptor.boundaryDescriptor());

      bc_list.push_back(FreeBoundaryCondition{mesh_node_boundary.nodeList()});
      break;
    }
    default: {
      is_valid_boundary_condition = false;
    }
    }
    if (not is_valid_boundary_condition) {
      std::ostringstream error_msg;
      error_msg << *bc_descriptor << " is an invalid boundary condition for upwind traffic flow solver";
      throw NormalError(error_msg.str());
    }
  }

  std::cout << "number of bc descr = " << bc_descriptor_list.size() << '\n';

  // Initial tau_j: the user function evaluated at the cell centers.
  CellValue<double> tauj(mesh->connectivity());

  {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

    tauj =
      InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(tau_id, mesh_data.xj());
  }

  // rho_j = 1 / tau_j and u_j = 1 - rho_j.
  // NOTE(review): rhoj and uj are kept up to date during the time loop but
  // are never read in this block -- confirm whether they are still needed.
  CellValue<double> rhoj = [&]() {
    CellValue<double> computed_rhoj(mesh->connectivity());
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { computed_rhoj[j] = 1. / tauj[j]; });
    return computed_rhoj;
  }();
  CellValue<double> uj = [&]() {
    CellValue<double> computed_uj(mesh->connectivity());
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { computed_uj[j] = 1 - rhoj[j]; });
    return computed_uj;
  }();

  CellValue<double> tauj_carre = [&]() {
    CellValue<double> computed_tauj_carre(mesh->connectivity());
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { computed_tauj_carre[j] = tauj[j] * tauj[j]; });
    return computed_tauj_carre;
  }();

  // c = 1 / min_j(tau_j^2), computed once from the initial data and used to
  // bound the time step.
  const double c = 1. / min(tauj_carre);

  double t = 0;

  const int itermax = std::numeric_limits<int>::max();
  int iteration     = 0;

  // M_j = V_j / tau_j on the initial mesh; it is not modified afterwards.
  const CellValue<const double> Mj = [&]() {
    MeshDataType& mesh_data           = MeshDataManager::instance().getMeshData(*mesh);
    const CellValue<const double>& Vj = mesh_data.Vj();

    CellValue<double> computed_Mj(mesh->connectivity());
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { computed_Mj[j] = 1. / tauj[j] * Vj[j]; });

    return computed_Mj;
  }();

  const CellValue<const double> inv_Mj = [&]() {
    CellValue<double> computed_inv_Mj(mesh->connectivity());
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { computed_inv_Mj[j] = 1. / Mj[j]; });

    return computed_inv_Mj;
  }();

  // Mj and c never change, so the unclamped time step is the same at every
  // iteration: compute it once instead of reducing over Mj in the loop.
  const double dt_max = 0.5 * min(Mj) / c;

  while ((t < tmax) and (iteration < itermax)) {
    // Shorten the last step so that the final time is exactly tmax.
    double dt = dt_max;
    if (t + dt > tmax) {
      dt = tmax - t;
    }
    t += dt;
    ++iteration;

    // Upwind flux at each node.
    const NodeValue<const TinyVector<1, double>> upwindflux = [&]() {
      const auto& node_to_cell_matrix = mesh->connectivity().nodeToCellMatrix();

      const auto& node_local_numbers_in_their_cells = mesh->connectivity().nodeLocalNumbersInTheirCells();
      NodeValue<TinyVector<1, double>> computed_flux(mesh->connectivity());
      parallel_for(
        mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) {
          const auto& node_cells = node_to_cell_matrix[r];
          const size_t nb_cells  = node_cells.size();
          double flux            = 0;
          for (size_t J = 0; J < nb_cells; ++J) {
            CellId j = node_cells[J];
            // With f_{j+1/2} = 1/tau_{j+1} - 1 and f_{j-1/2} = 1/tau_j - 1,
            // the flux at node r is taken from the cell in which r is local
            // node 0 (the left node of that cell).
            if (node_local_numbers_in_their_cells(r, J) == 0) {
              flux = 1. / tauj[j];
            }
          }
          // A node that is local node 0 of no cell keeps flux == 0 here and
          // therefore gets the value -1.
          computed_flux[r] = Rd{flux - 1};
        });

      // Velocity boundary conditions overwrite the flux at their nodes with
      // minus the imposed value; free boundary conditions leave it unchanged.
      for (const auto& boundary_condition : bc_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
              const auto& node_list  = bc.nodeList();
              const auto& value_list = bc.valueList();

              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id = node_list[i_node];

                computed_flux[node_id] = -value_list[i_node];
              }
            }
          },
          boundary_condition);
      }

      return computed_flux;
    }();

    // Cjr is read from the data of the mesh at the beginning of the step,
    // i.e. before the nodes are moved below.
    MeshDataType& mesh_data                                 = MeshDataManager::instance().getMeshData(*mesh);
    const NodeValuePerCell<const TinyVector<1, double>> Cjr = mesh_data.Cjr();

    const auto& cell_to_node_matrix = mesh->connectivity().cellToNodeMatrix();

    // Node velocity is the opposite of the node flux.
    const NodeValue<const TinyVector<1, double>> ur = [&]() {
      NodeValue<TinyVector<1, double>> computed_ur(mesh->connectivity());
      parallel_for(
        mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) { computed_ur[r] = -upwindflux[r]; });

      return computed_ur;
    }();

    // Move the nodes and build the new mesh on the same connectivity.
    NodeValue<Rd> new_xr = copy(mesh->xr());
    parallel_for(
      mesh->connectivity().numberOfNodes(), PUGS_LAMBDA(NodeId r) { new_xr[r] += dt * ur[r]; });
    mesh = std::make_shared<MeshType>(mesh->shared_connectivity(), new_xr);

    // Update the cell values: tau_j -= dt / M_j * sum_r (C_jr . flux_r).
    for (CellId j = 0; j < mesh->connectivity().numberOfCells(); ++j) {
      const auto& cell_nodes = cell_to_node_matrix[j];
      double flux_sum        = 0;
      for (size_t r = 0; r < cell_nodes.size(); ++r) {
        const NodeId nodeid = cell_nodes[r];
        flux_sum += dot(Cjr(j, r), upwindflux[nodeid]);
      }

      tauj[j] -= dt * inv_Mj[j] * flux_sum;
      rhoj[j] = 1. / tauj[j];
      uj[j]   = 1 - rhoj[j];
    }
  }
}

// Boundary condition imposing a value at a list of nodes.
//
// Stores a list of node ids together with a list of values; the solver reads
// the value at index i for the node at index i.
template <MeshConcept MeshType>
class UpwindExplicitTrafficAlgorithm<MeshType>::VelocityBoundaryCondition
{
 private:
  constexpr static size_t Dimension = MeshType::Dimension;

  using ValueArray  = Array<const TinyVector<Dimension>>;
  using NodeIdArray = Array<const NodeId>;

 public:
  VelocityBoundaryCondition(const Array<const NodeId>& node_list, const Array<const TinyVector<Dimension>>& value_list)
    : m_value_list{value_list}, m_node_list{node_list}
  {}

  ~VelocityBoundaryCondition() = default;

  // Ids of the nodes on which the condition applies.
  const NodeIdArray&
  nodeList() const
  {
    return m_node_list;
  }

  // Imposed values, stored alongside the node list.
  const ValueArray&
  valueList() const
  {
    return m_value_list;
  }

 private:
  const ValueArray m_value_list;
  const NodeIdArray m_node_list;
};

// Free boundary condition: only records the ids of the boundary nodes it
// applies to.
template <MeshConcept MeshType>
class UpwindExplicitTrafficAlgorithm<MeshType>::FreeBoundaryCondition
{
 private:
  const Array<const NodeId> m_node_list;

 public:
  // Ids of the nodes on which the condition applies.
  const Array<const NodeId>&
  nodeList() const
  {
    return m_node_list;
  }

  // explicit: a node list must not silently convert into a boundary condition.
  explicit FreeBoundaryCondition(const Array<const NodeId>& node_list) : m_node_list{node_list} {}

  ~FreeBoundaryCondition() = default;
};

// Explicit instantiation of the constructor for Mesh<1>; it is the only
// instantiation provided in this file.
template UpwindExplicitTrafficAlgorithm<Mesh<1>>::UpwindExplicitTrafficAlgorithm(
  std::shared_ptr<const MeshVariant>,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
  const FunctionSymbolId&,
  const double&);
