#include <scheme/UpwindExplicitTrafficSolver.hpp>

#include <language/utils/FunctionSymbolId.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/DiscreteFunctionVariant.hpp>
#include <scheme/FreeBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>
#include <variant>
#include <vector>

// Computes the explicit time-step bound used by the upwind traffic-flow solver.
//
// For every cell j the local bound is 0.5 * (V_j / tau_j) / c_j, i.e. half the
// cell mass M_j = V_j / tau_j divided by c_j; the returned time step is the
// minimum of these local bounds over all cells.
//
// @param c_v   P0 scalar function c (presumably a characteristic speed -- TODO confirm)
// @param tau_v P0 scalar function tau (specific volume, tau = 1 / rho)
// @return      min_j 0.5 * (V_j / tau_j) / c_j
// @throws NormalError     if c_v and tau_v are not defined on the same mesh
// @throws UnexpectedError if the common mesh is not a Mesh<1>
double
acoustic_dt_traffic_flow(const std::shared_ptr<const DiscreteFunctionVariant>& c_v,
                         const std::shared_ptr<const DiscreteFunctionVariant>& tau_v)
{
  std::shared_ptr mesh_v = getCommonMesh({c_v, tau_v});
  // getCommonMesh yields a null pointer when the functions live on different
  // meshes; dereferencing it below would be undefined behavior.
  if (not mesh_v) {
    throw NormalError("discrete functions are not defined on the same mesh");
  }

  auto c   = c_v->get<DiscreteFunctionP0<const double>>();
  auto tau = tau_v->get<DiscreteFunctionP0<const double>>();

  return std::visit(
    [&](auto&& mesh) -> double {
      if constexpr (std::is_same_v<Mesh<1>, mesh_type_t<decltype(mesh)>>) {
        const auto Vj = MeshDataManager::instance().getMeshData(*mesh).Vj();

        CellValue<double> local_dt{mesh->connectivity()};
        parallel_for(
          mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { local_dt[j] = 0.5 * (Vj[j] / tau[j]) / c[j]; });

        return min(local_dt);
      } else {
        throw UnexpectedError("invalid mesh type");
      }
    },
    mesh_v->variant());
}

// Explicit upwind Lagrangian solver for the traffic-flow model written in
// specific-volume form (tau = 1 / rho).
//
// At construction, an upwind flux f_r is computed at every node r from the
// cell for which r is local node 0 (f_r = 1 / tau_j - 1); on "velocity"
// Dirichlet boundaries it is replaced by -u_bc. The node velocity is
// u_r = -f_r.
//
// apply() moves the nodes with u_r and updates tau with the explicit scheme
//   tau_j^{n+1} = tau_j^n - dt / M_j * sum_r C_jr . f_r,   M_j = V_j^n / tau_j^n,
// where every geometric quantity (V_j, C_jr) is taken on the mesh at time n.
//
// The solver keeps no reference to the construction mesh: each method receives
// the mesh it works on, so nothing can dangle once that mesh is released.
template <MeshConcept MeshType>
class UpwindExplicitTrafficSolverHandler::UpwindExplicitTrafficSolver final
  : public UpwindExplicitTrafficSolverHandler::IUpwindExplicitTrafficSolver
{
  constexpr static size_t Dimension = MeshType::Dimension;

  using Rdxd = TinyMatrix<Dimension>;
  using Rd   = TinyVector<Dimension>;

  using MeshDataType = MeshData<MeshType>;

  using DiscreteScalarFunction = DiscreteFunctionP0<const double>;
  using DiscreteVectorFunction = DiscreteFunctionP0<const Rd>;

  class FreeBoundaryCondition;
  class VelocityBoundaryCondition;

  using BoundaryCondition = std::variant<FreeBoundaryCondition, VelocityBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;

  BoundaryConditionList m_boundary_condition_list;

  // Node velocities u_r = -f_r, used to move the mesh in apply().
  NodeValue<const Rd> m_ur;
  // Upwind node fluxes f_r.
  NodeValue<const Rd> m_upwind_flux;

  // Computes the upwind node fluxes f_r on `mesh` from the specific volume tau,
  // then overrides them on velocity boundaries with -u_bc.
  NodeValue<const Rd>
  _getUpwindFlux(const MeshType& mesh, const DiscreteScalarFunction& tau) const
  {
    const auto& node_to_cell_matrix               = mesh.connectivity().nodeToCellMatrix();
    const auto& node_local_numbers_in_their_cells = mesh.connectivity().nodeLocalNumbersInTheirCells();

    NodeValue<Rd> computed_flux(mesh.connectivity());
    parallel_for(
      mesh.numberOfNodes(), PUGS_LAMBDA(NodeId r) {
        const auto& node_cells = node_to_cell_matrix[r];
        const size_t nb_cells  = node_cells.size();
        double flux            = 0;
        for (size_t J = 0; J < nb_cells; ++J) {
          const CellId j = node_cells[J];
          // f_{j+1/2} = 1/tau_{j+1} - 1 and f_{j-1/2} = 1/tau_j - 1: the flux at a
          // node is evaluated from the cell for which it is the left node (local
          // number 0). A node that is local node 0 of no cell keeps flux = 0,
          // hence gets f_r = -1.
          if (node_local_numbers_in_their_cells(r, J) == 0) {
            flux = 1. / tau[j];
          }
        }
        computed_flux[r] = Rd{flux - 1};
      });

    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const auto& node_list  = bc.nodeList();
            const auto& value_list = bc.valueList();

            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId& node_id = node_list[i_node];

              // u_r = -f_r, so imposing u_bc means f_r = -u_bc.
              computed_flux[node_id] = -value_list[i_node];
            }
          }
        },
        boundary_condition);
    }

    return computed_flux;
  }

  // Translates boundary-condition descriptors into the solver's condition list.
  // Only "velocity" Dirichlet and free conditions are accepted.
  BoundaryConditionList
  _getBCList(const std::shared_ptr<const MeshType>& mesh,
             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list) const
  {
    BoundaryConditionList bc_list;

    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if (dirichlet_bc_descriptor.name() == "velocity") {
          MeshNodeBoundary mesh_node_boundary =
            getMeshNodeBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

          Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
            TinyVector<Dimension>)>::template interpolate<ItemType::node>(dirichlet_bc_descriptor.rhsSymbolId(),
                                                                          mesh->xr(), mesh_node_boundary.nodeList());

          bc_list.push_back(VelocityBoundaryCondition{mesh_node_boundary.nodeList(), value_list});
        } else {
          is_valid_boundary_condition = false;
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::free: {   // <- Free boundary
        const FreeBoundaryConditionDescriptor& free_bc_descriptor =
          dynamic_cast<const FreeBoundaryConditionDescriptor&>(*bc_descriptor);

        MeshNodeBoundary mesh_node_boundary = getMeshNodeBoundary(*mesh, free_bc_descriptor.boundaryDescriptor());

        bc_list.push_back(FreeBoundaryCondition{mesh_node_boundary.nodeList()});
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for acoustic solver";
        throw NormalError(error_msg.str());
      }
    }

    return bc_list;
  }

  // Node velocities deduced from the upwind fluxes: u_r = -f_r.
  NodeValue<const Rd>
  _computeUr(const MeshType& mesh, const NodeValue<const Rd>& upwind_flux) const
  {
    NodeValue<Rd> computed_ur(mesh.connectivity());
    parallel_for(
      mesh.numberOfNodes(), PUGS_LAMBDA(NodeId r) { computed_ur[r] = -upwind_flux[r]; });

    return computed_ur;
  }

  UpwindExplicitTrafficSolver(
    const std::shared_ptr<const MeshType>& p_mesh,
    const DiscreteScalarFunction& tau,
    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
    : m_boundary_condition_list{this->_getBCList(p_mesh, bc_descriptor_list)}
  {
    const MeshType& mesh = *p_mesh;

    m_upwind_flux = this->_getUpwindFlux(mesh, tau);
    m_ur          = this->_computeUr(mesh, m_upwind_flux);
  }

 public:
  // Performs one explicit time step of length dt on `mesh`, returning the moved
  // mesh and the updated specific volume defined on it.
  std::tuple<std::shared_ptr<const MeshVariant>, std::shared_ptr<const DiscreteFunctionVariant>>
  apply(const double& dt, const std::shared_ptr<const MeshType>& mesh, const DiscreteScalarFunction& tau) const
  {
    const auto& cell_to_node_matrix = mesh->connectivity().cellToNodeMatrix();

    // Local copies so that the parallel lambdas do not capture `this`.
    const NodeValue<const Rd> ur          = m_ur;
    const NodeValue<const Rd> upwind_flux = m_upwind_flux;

    // Explicit scheme: volumes and corner normals are those of the mesh at time n,
    // so that M_j = V_j^n / tau_j^n is the conserved cell mass.
    MeshDataType& mesh_data              = MeshDataManager::instance().getMeshData(*mesh);
    const CellValue<const double> Vj     = mesh_data.Vj();
    const NodeValuePerCell<const Rd> Cjr = mesh_data.Cjr();

    NodeValue<Rd> new_xr = copy(mesh->xr());
    parallel_for(
      mesh->numberOfNodes(), PUGS_LAMBDA(NodeId r) { new_xr[r] += dt * ur[r]; });

    std::shared_ptr<const MeshType> new_mesh = std::make_shared<MeshType>(mesh->shared_connectivity(), new_xr);

    CellValue<double> new_tau = copy(tau.cellValues());

    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) {
        const auto& cell_nodes = cell_to_node_matrix[j];
        double flux_sum        = 0;
        for (size_t R = 0; R < cell_nodes.size(); ++R) {
          const NodeId r = cell_nodes[R];
          flux_sum += dot(Cjr(j, R), upwind_flux[r]);
        }
        // dt / M_j with M_j = V_j^n / tau_j^n
        const double dt_over_Mj = dt * tau[j] / Vj[j];
        new_tau[j] -= dt_over_Mj * flux_sum;
      });

    return {std::make_shared<MeshVariant>(new_mesh),
            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction{new_mesh, new_tau})};
  }

  std::tuple<std::shared_ptr<const MeshVariant>, std::shared_ptr<const DiscreteFunctionVariant>>
  apply(const double& dt, const std::shared_ptr<const DiscreteFunctionVariant>& tau) const
  {
    std::shared_ptr mesh_v = getCommonMesh({tau});
    if (not mesh_v) {
      throw NormalError("discrete function is not defined on the same mesh");
    }
    if (not checkDiscretizationType({tau}, DiscreteFunctionType::P0)) {
      throw NormalError("upwind traffic solver expects P0 functions");
    }
    return this->apply(dt, mesh_v->get<MeshType>(), tau->get<DiscreteScalarFunction>());
  }

  UpwindExplicitTrafficSolver(
    const std::shared_ptr<const MeshVariant>& mesh_v,
    const std::shared_ptr<const DiscreteFunctionVariant>& tau,
    const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
    : UpwindExplicitTrafficSolver{mesh_v->get<MeshType>(), tau->get<DiscreteScalarFunction>(), bc_descriptor_list}
  {}

  UpwindExplicitTrafficSolver()                              = default;
  UpwindExplicitTrafficSolver(UpwindExplicitTrafficSolver&&) = default;
  ~UpwindExplicitTrafficSolver()                             = default;
};

template <MeshConcept MeshType>
class UpwindExplicitTrafficSolverHandler::UpwindExplicitTrafficSolver<MeshType>::VelocityBoundaryCondition
{
 public:
  // Prescribed-velocity condition: value_list[i] is the velocity imposed at
  // node node_list[i].
  VelocityBoundaryCondition(const Array<const NodeId>& node_list, const Array<const TinyVector<Dimension>>& value_list)
    : m_value_list{value_list}, m_node_list{node_list}
  {}

  ~VelocityBoundaryCondition() = default;

  // Boundary nodes on which the velocity is imposed.
  const Array<const NodeId>& nodeList() const { return m_node_list; }

  // Imposed velocities, aligned index-by-index with nodeList().
  const Array<const TinyVector<Dimension>>& valueList() const { return m_value_list; }

 private:
  const Array<const TinyVector<Dimension>> m_value_list;
  const Array<const NodeId> m_node_list;
};

// Free boundary condition: only records its boundary nodes. The upwind flux
// computation applies no override to these nodes.
template <MeshConcept MeshType>
class UpwindExplicitTrafficSolverHandler::UpwindExplicitTrafficSolver<MeshType>::FreeBoundaryCondition
{
 private:
  const Array<const NodeId> m_node_list;

 public:
  // Boundary nodes carrying the free condition.
  const Array<const NodeId>&
  nodeList() const
  {
    return m_node_list;
  }

  // explicit: a bare node list must not silently convert into a boundary condition.
  explicit FreeBoundaryCondition(const Array<const NodeId>& node_list) : m_node_list{node_list} {}

  ~FreeBoundaryCondition() = default;
};

UpwindExplicitTrafficSolverHandler::UpwindExplicitTrafficSolverHandler(
  const std::shared_ptr<const DiscreteFunctionVariant>& tau,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
{
  // Validate the input before building the concrete, mesh-typed solver.
  std::shared_ptr p_mesh_variant = getCommonMesh({tau});
  if (not p_mesh_variant) {
    throw NormalError("discrete function is not defined on the same mesh");
  }
  if (not checkDiscretizationType({tau}, DiscreteFunctionType::P0)) {
    throw NormalError("Upwind solver expects P0 functions");
  }

  // Only Mesh<1> is supported; any other mesh held by the variant is rejected.
  std::visit(
    [&](auto&& p_mesh) {
      using MeshType = mesh_type_t<decltype(p_mesh)>;
      if constexpr (not std::is_same_v<MeshType, Mesh<1>>) {
        throw NormalError("unexpected mesh type");
      } else {
        m_upwind_explicit_traffic_solver =
          std::make_unique<UpwindExplicitTrafficSolver<MeshType>>(p_mesh_variant, tau, bc_descriptor_list);
      }
    },
    p_mesh_variant->variant());
}
