#include <scheme/ImplicitExactTrafficSolver.hpp>

#include <algebra/CRSMatrix.hpp>
#include <algebra/CRSMatrixDescriptor.hpp>
#include <algebra/LinearSolver.hpp>
#include <algebra/Vector.hpp>
#include <language/utils/FunctionSymbolId.hpp>
#include <language/utils/InterpolateItemValue.hpp>
#include <mesh/ItemValueUtils.hpp>
#include <mesh/MeshNodeBoundary.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/DiscreteFunctionVariant.hpp>
#include <scheme/FreeBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>

#include <cmath>
#include <variant>
#include <vector>

template <MeshConcept MeshType>
class ImplicitExactTrafficSolverHandler::ImplicitExactTrafficSolver final
  : public ImplicitExactTrafficSolverHandler::IImplicitExactTrafficSolver
{
  constexpr static size_t Dimension = MeshType::Dimension;

  using Rdxd = TinyMatrix<Dimension>;
  using Rd   = TinyVector<Dimension>;

  using MeshDataType = MeshData<MeshType>;

  using DiscreteScalarFunction = DiscreteFunctionP0<const double>;
  using DiscreteVectorFunction = DiscreteFunctionP0<const Rd>;

  class FreeBoundaryCondition;
  class VelocityBoundaryCondition;

  using BoundaryCondition = std::variant<FreeBoundaryCondition, VelocityBoundaryCondition>;

  using BoundaryConditionList = std::vector<BoundaryCondition>;

  BoundaryConditionList m_boundary_condition_list;
  const MeshType& m_mesh;

  // Cell masses Vj / tau_j at the beginning of the step (not read anywhere in this class yet).
  CellValue<const double> m_Mj;

  // Specific volume at the end of the step, computed once by the constructor and read by apply().
  CellValue<double> m_tau;

  // Returns Mj = Vj / tau_j for every cell of m_mesh.
  CellValue<const double>
  _getMj(const DiscreteScalarFunction& tau)
  {
    MeshDataType& mesh_data           = MeshDataManager::instance().getMeshData(m_mesh);
    const CellValue<const double>& Vj = mesh_data.Vj();
    CellValue<double> Mj{m_mesh.connectivity()};
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { Mj[j] = Vj[j] * 1. / tau[j]; });
    return Mj;
  }

  // Returns the root of x^2 - b*x - c = 0 retained by the scheme: the larger root when it is
  // positive, the smaller one otherwise, and b/2 for a double root.
  //
  // Throws UnexpectedError when there is no real root (negative or NaN discriminant).
  static double
  _solveQuadratic(const double b, const double c)
  {
    const double delta = b * b + 4 * c;
    if (delta > 0) {
      const double sqrt_delta = std::sqrt(delta);
      const double x_1        = (b + sqrt_delta) / 2;
      return (x_1 > 0) ? x_1 : (b - sqrt_delta) / 2;
    } else if (delta == 0) {
      return b / 2;
    }
    throw UnexpectedError("cannot find root");
  }

  // Computes the end-of-step specific volume tau^{n+1} cell by cell.
  //
  // Starting from the cell adjacent to each velocity boundary node, every cell's implicit update
  // reduces to a quadratic in its own tau^{n+1}. The quadratic's coefficients involve only the
  // previously solved neighbour (or the prescribed boundary velocity for the first cell).
  // Cells not reached by any sweep keep their current tau.
  CellValue<double>
  _getExactTau(const DiscreteScalarFunction& tau, const double dt)
  {
    const auto& cell_to_node_matrix = m_mesh.connectivity().cellToNodeMatrix();
    const auto& node_to_cell_matrix = m_mesh.connectivity().nodeToCellMatrix();

    MeshDataType& mesh_data           = MeshDataManager::instance().getMeshData(m_mesh);
    const CellValue<const double>& Vj = mesh_data.Vj();

    CellValue<double> inv_Mj(m_mesh.connectivity());
    parallel_for(
      m_mesh.numberOfCells(), PUGS_LAMBDA(CellId j) { inv_Mj[j] = tau[j] / Vj[j]; });

    // Start from the current values so that no cell is ever returned uninitialized.
    CellValue<double> taunext = copy(tau.cellValues());

    for (const auto& boundary_condition : m_boundary_condition_list) {
      std::visit(
        [&](auto&& bc) {
          using T = std::decay_t<decltype(bc)>;
          if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
            const auto& node_list  = bc.nodeList();
            const auto& value_list = bc.valueList();

            for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
              const NodeId& node_id = node_list[i_node];
              const auto& node_cell = node_to_cell_matrix[node_id];
              CellId j              = node_cell[0];

              // Boundary cell: the boundary node velocity is the prescribed value.
              taunext[j] = _solveQuadratic(tau[j] - dt * inv_Mj[j] * (1 - value_list[i_node][0]), dt * inv_Mj[j]);

              // NOTE(review): the sweep goes cell j -> its first node -> that node's first cell. With a
              // 1D mesh numbered left to right this walks leftwards, so it presumably expects the
              // velocity condition to be on the rightmost node; confirm.
              for (size_t iter = 1; iter < m_mesh.numberOfCells(); ++iter) {
                const auto& cell_node   = cell_to_node_matrix[j];
                const NodeId r          = cell_node[0];
                const auto& node_cell_2 = node_to_cell_matrix[r];
                const CellId j_1        = node_cell_2[0];

                taunext[j_1] = _solveQuadratic(tau[j_1] - (dt * inv_Mj[j_1]) * (1. / taunext[j]), dt * inv_Mj[j_1]);

                j = j_1;
              }
            }
          }
        },
        boundary_condition);
    }
    return taunext;
  }

  // Translates boundary condition descriptors into the solver's boundary conditions.
  //
  // Accepts Dirichlet conditions named "velocity" and free conditions; throws NormalError for
  // anything else.
  BoundaryConditionList
  _getBCList(const std::shared_ptr<const MeshType>& mesh,
             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
  {
    BoundaryConditionList bc_list;

    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if (dirichlet_bc_descriptor.name() == "velocity") {
          MeshNodeBoundary mesh_node_boundary =
            getMeshNodeBoundary(*mesh, dirichlet_bc_descriptor.boundaryDescriptor());

          Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
            TinyVector<Dimension>)>::template interpolate<ItemType::node>(dirichlet_bc_descriptor.rhsSymbolId(),
                                                                          mesh->xr(), mesh_node_boundary.nodeList());

          bc_list.push_back(VelocityBoundaryCondition{mesh_node_boundary.nodeList(), value_list});
        } else {
          is_valid_boundary_condition = false;
        }
        break;
      }
      case IBoundaryConditionDescriptor::Type::free: {
        const FreeBoundaryConditionDescriptor& free_bc_descriptor =
          dynamic_cast<const FreeBoundaryConditionDescriptor&>(*bc_descriptor);

        MeshNodeBoundary mesh_node_boundary = getMeshNodeBoundary(*mesh, free_bc_descriptor.boundaryDescriptor());

        bc_list.push_back(FreeBoundaryCondition{mesh_node_boundary.nodeList()});
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for this solver";
        throw NormalError(error_msg.str());
      }
    }

    return bc_list;
  }

  // Builds the boundary conditions, then solves for the end-of-step specific volume once.
  ImplicitExactTrafficSolver(const std::shared_ptr<const MeshType>& p_mesh,
                             const DiscreteScalarFunction& tau,
                             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
                             const double& dt)
    : m_boundary_condition_list{this->_getBCList(p_mesh, bc_descriptor_list)}, m_mesh{*p_mesh}
  {
    m_Mj  = this->_getMj(tau);
    m_tau = this->_getExactTau(tau, dt);
  }

 public:
  // Moves the nodes with the velocity u = 1 - 1/tau of the solved cells, or with the prescribed
  // boundary velocity. Returns the new mesh together with the solved tau defined on it.
  std::tuple<std::shared_ptr<const MeshVariant>, std::shared_ptr<const DiscreteFunctionVariant>>
  apply(const double& dt, const std::shared_ptr<const MeshType>& mesh, const DiscreteScalarFunction& tau) const
  {
    const NodeValue<Rd> new_position = [&]() {
      const auto& node_to_cell_matrix = mesh->connectivity().nodeToCellMatrix();
      const auto& xr                  = mesh->xr();

      NodeValue<Rd> x_next(mesh->connectivity());

      // Each node takes the velocity of its last adjacent cell. Boundary nodes have one cell
      // (index 0) and interior nodes two (index 1). This also gives the last node a position when
      // it carries no velocity condition (e.g. a free boundary).
      for (NodeId r = 0; r < mesh->connectivity().numberOfNodes(); ++r) {
        const auto& node_cells = node_to_cell_matrix[r];
        const CellId j         = node_cells[node_cells.size() - 1];
        x_next[r]              = xr[r] + dt * Rd{1 - 1. / m_tau[j]};
      }

      // Prescribed velocities override the cell-based ones on their nodes.
      for (const auto& boundary_condition : m_boundary_condition_list) {
        std::visit(
          [&](auto&& bc) {
            using T = std::decay_t<decltype(bc)>;
            if constexpr (std::is_same_v<VelocityBoundaryCondition, T>) {
              const auto& node_list  = bc.nodeList();
              const auto& value_list = bc.valueList();
              for (size_t i_node = 0; i_node < node_list.size(); ++i_node) {
                const NodeId& node_id = node_list[i_node];

                x_next[node_id] = xr[node_id] + dt * value_list[i_node];
              }
            }
          },
          boundary_condition);
      }
      return x_next;
    }();

    std::shared_ptr<const MeshType> new_mesh = std::make_shared<MeshType>(mesh->shared_connectivity(), new_position);

    // Copying tau gives an array attached to the connectivity shared with new_mesh; its values are
    // then replaced by the solved ones.
    CellValue<double> new_tau = copy(tau.cellValues());
    parallel_for(
      new_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { new_tau[j] = m_tau[j]; });

    return {std::make_shared<MeshVariant>(new_mesh),
            std::make_shared<DiscreteFunctionVariant>(DiscreteScalarFunction{new_mesh, new_tau})};
  }

  // Variant-based entry point: checks that tau is a P0 function on a mesh, then dispatches.
  std::tuple<std::shared_ptr<const MeshVariant>, std::shared_ptr<const DiscreteFunctionVariant>>
  apply(const double& dt, const std::shared_ptr<const DiscreteFunctionVariant>& tau) const
  {
    std::shared_ptr mesh_v = getCommonMesh({tau});
    if (not mesh_v) {
      throw NormalError("discrete function is not defined on the same mesh");
    }
    if (not checkDiscretizationType({tau}, DiscreteFunctionType::P0)) {
      throw NormalError("implicit traffic solver expects P0 functions");
    }
    return this->apply(dt, mesh_v->get<MeshType>(), tau->get<DiscreteScalarFunction>());
  }

  ImplicitExactTrafficSolver(const std::shared_ptr<const MeshVariant>& mesh_v,
                             const std::shared_ptr<const DiscreteFunctionVariant>& tau,
                             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
                             const double& dt)
    : ImplicitExactTrafficSolver(mesh_v->get<MeshType>(), tau->get<DiscreteScalarFunction>(), bc_descriptor_list, dt)
  {}

  // The reference member m_mesh makes a default-constructed solver impossible.
  ImplicitExactTrafficSolver()                             = delete;
  ImplicitExactTrafficSolver(ImplicitExactTrafficSolver&&) = default;
  ~ImplicitExactTrafficSolver()                            = default;
};

template <MeshConcept MeshType>
class ImplicitExactTrafficSolverHandler::ImplicitExactTrafficSolver<MeshType>::VelocityBoundaryCondition
{
 public:
  // Stores the boundary nodes and the velocity prescribed on them; value_list[i] belongs to
  // node_list[i].
  VelocityBoundaryCondition(const Array<const NodeId>& node_list, const Array<const TinyVector<Dimension>>& value_list)
    : m_value_list{value_list}, m_node_list{node_list}
  {}

  ~VelocityBoundaryCondition() = default;

  // Nodes on which the velocity is prescribed.
  const Array<const NodeId>&
  nodeList() const
  {
    return m_node_list;
  }

  // Prescribed velocities, indexed like nodeList().
  const Array<const TinyVector<Dimension>>&
  valueList() const
  {
    return m_value_list;
  }

 private:
  const Array<const TinyVector<Dimension>> m_value_list;
  const Array<const NodeId> m_node_list;
};

template <MeshConcept MeshType>
class ImplicitExactTrafficSolverHandler::ImplicitExactTrafficSolver<MeshType>::FreeBoundaryCondition
{
 public:
  // Stores the nodes of a free boundary; no value is prescribed on them.
  FreeBoundaryCondition(const Array<const NodeId>& node_list) : m_node_list{node_list} {}

  ~FreeBoundaryCondition() = default;

  // Nodes belonging to the free boundary.
  const Array<const NodeId>&
  nodeList() const
  {
    return m_node_list;
  }

 private:
  const Array<const NodeId> m_node_list;
};

ImplicitExactTrafficSolverHandler::ImplicitExactTrafficSolverHandler(
  const std::shared_ptr<const DiscreteFunctionVariant>& tau_v,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const double& dt)
{
  // The solver is built on the mesh that carries tau_v.
  std::shared_ptr mesh_v = getCommonMesh({tau_v});
  if (not mesh_v) {
    throw NormalError("discrete function is not defined on the same mesh");
  }

  const bool is_p0 = checkDiscretizationType({tau_v}, DiscreteFunctionType::P0);
  if (not is_p0) {
    throw NormalError("Implicit solver expects P0 functions");
  }

  // Only one-dimensional meshes are supported; any other mesh type is rejected.
  auto build_solver = [&](auto&& mesh) {
    using MeshType = mesh_type_t<decltype(mesh)>;
    if constexpr (not std::is_same_v<MeshType, Mesh<1>>) {
      throw NormalError("unexpected mesh type");
    } else {
      m_implicit_exact_traffic_solver =
        std::make_unique<ImplicitExactTrafficSolver<MeshType>>(mesh_v, tau_v, bc_descriptor_list, dt);
    }
  };

  std::visit(build_solver, mesh_v->variant());
}
