#include <scheme/FluxingAdvectionSolver.hpp>

#include <language/utils/EvaluateAtPoints.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/IMesh.hpp>
#include <mesh/Mesh.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>

template <size_t Dimension>
class FluxingAdvectionSolver
{
 private:
  using Rd = TinyVector<Dimension>;

  using MeshType = Mesh<Connectivity<Dimension>>;

  const std::shared_ptr<const MeshType> m_old_mesh;
  const std::shared_ptr<const MeshType> m_new_mesh;
  const DiscreteFunctionP0<Dimension, const double> m_old_q;

 public:
  // CellValue<double>
  // compute_PFnp1(const DiscreteFunctionP0<Dimension, const double> F, const double& dt, const double& dx)
  // {
  //   CellValue<double> PFnp1{m_mesh.connectivity()};

  //   DiscreteFunctionP0<Dimension, double> deltaF  = compute_delta2Fn(F);
  //   DiscreteFunctionP0<Dimension, double> deltaF0 = compute_delta2Fn(m_Fn);

  //   for (CellId cell_id = 0; cell_id < m_mesh.numberOfCells(); ++cell_id) {
  //     PFnp1[cell_id] = m_Fn[cell_id][0] + m_Fn[cell_id][1] -
  //                      (0.5 * dt / dx) * m_lambda * (deltaF[cell_id][0] - deltaF[cell_id][1]) -
  //                      (0.5 * dt / dx) * m_lambda * (deltaF0[cell_id][0] - deltaF0[cell_id][1]);
  //   }
  //   return PFnp1;
  // }

  // DiscreteFunctionP0<Dimension, double>
  // apply(const double& dt, const double& eps)
  // {
  //   const DiscreteFunctionP0<Dimension, const double>& F0 = m_Fn;
  //   DiscreteFunctionP0<Dimension, double> Fnp1            = copy(F0);
  //   DiscreteFunctionP0<Dimension, double> deltaFn         = compute_delta2Fn(F0);

  //   for (size_t p = 0; p < 2; ++p) {
  //     CellId first_cell_id = 0;
  //     const double dx      = m_dx_table[first_cell_id];

  //     DiscreteFunctionP0<Dimension, double> deltaFnp1 = compute_delta2Fn(Fnp1);

  //     const CellValue<const double> PFnp1  = compute_PFnp1(Fnp1, dt, dx);
  //     const CellValue<const double> APFnp1 = getA(PFnp1);
  //     const CellArray<const double> MPFnp1 = compute_M(PFnp1, APFnp1);
  //     const CellValue<const double> PFn    = compute_PFn(F0);
  //     const CellValue<const double> APFn   = getA(PFn);
  //     const CellArray<const double> MPFn   = compute_M(PFn, APFn);

  //     for (CellId cell_id = 0; cell_id < m_mesh.numberOfCells(); ++cell_id) {
  //       Fnp1[cell_id][0] = 1. / (1 + 0.5 * dt / eps) *
  //                          ((0.5 * dt / eps) * MPFnp1[cell_id][0] + F0[cell_id][0] -
  //                           (0.5 * dt / dx) * m_lambda * (deltaFnp1[cell_id][0] + deltaFn[cell_id][0]) +
  //                           (0.5 * dt / eps) * (MPFn[cell_id][0] - F0[cell_id][0]));
  //       Fnp1[cell_id][1] = 1. / (1 + 0.5 * dt / eps) *
  //                          ((0.5 * dt / eps) * MPFnp1[cell_id][1] + F0[cell_id][1] +
  //                           (0.5 * dt / dx) * m_lambda * (deltaFnp1[cell_id][1] + deltaFn[cell_id][1]) +
  //                           (0.5 * dt / eps) * (MPFn[cell_id][1] - F0[cell_id][1]));
  //     }
  //   }

  //   return Fnp1;
  // }

  DiscreteFunctionP0<Dimension, double>
  apply()
  {
    DiscreteFunctionP0<Dimension, double> new_q(m_new_mesh);
    if (m_new_mesh->shared_connectivity() != m_old_mesh->shared_connectivity()) {
      throw NormalError("Old and new meshes must share the same connectivity");
    }
    return new_q;
  }

  FluxingAdvectionSolver(const std::shared_ptr<const MeshType> old_mesh,
                         const std::shared_ptr<const MeshType> new_mesh,
                         const DiscreteFunctionP0<Dimension, const double>& old_q)
    : m_old_mesh{old_mesh}, m_new_mesh{new_mesh}
  {}

  ~FluxingAdvectionSolver() = default;
};

/// @brief Runs the fluxing advection solver, dispatching on the mesh dimension.
///
/// @param new_mesh mesh onto which old_q_v is transported; must be non-null and
///        have the same dimension as the mesh old_q_v is defined on
/// @param old_q_v  P0 discrete function defined on the old mesh
/// @return the transported function, defined on new_mesh
/// @throws NormalError if old_q_v is not P0, new_mesh is null, or the meshes'
///         dimensions differ
/// @throws NotImplementedError in dimensions 2 and 3
std::shared_ptr<const DiscreteFunctionVariant>
FluxingAdvectionSolverHandler(const std::shared_ptr<const IMesh> new_mesh,
                              const std::shared_ptr<const DiscreteFunctionVariant>& old_q_v)
{
  const std::shared_ptr<const IMesh> old_mesh = getCommonMesh({old_q_v});

  if (not checkDiscretizationType({old_q_v}, DiscreteFunctionType::P0)) {
    throw NormalError("fluxing advection solver expects P0 functions");
  }

  // Without these checks the dynamic_pointer_cast below would silently yield
  // nullptr for new_mesh, which would then be dereferenced by the solver.
  if (not new_mesh) {
    throw NormalError("Fluxing advection solver requires a non-null new mesh");
  }
  if (new_mesh->dimension() != old_mesh->dimension()) {
    throw NormalError("Old and new meshes must have the same dimension");
  }

  switch (old_mesh->dimension()) {
  case 1: {
    constexpr size_t Dimension = 1;
    using MeshType             = Mesh<Connectivity<Dimension>>;

    const std::shared_ptr<const MeshType> old_mesh0 = std::dynamic_pointer_cast<const MeshType>(old_mesh);
    const std::shared_ptr<const MeshType> new_mesh0 = std::dynamic_pointer_cast<const MeshType>(new_mesh);

    const DiscreteFunctionP0<Dimension, const double>& old_q =
      old_q_v->get<DiscreteFunctionP0<Dimension, const double>>();

    FluxingAdvectionSolver<Dimension> solver(old_mesh0, new_mesh0, old_q);

    return std::make_shared<DiscreteFunctionVariant>(solver.apply());
  }
  case 2: {
    throw NotImplementedError("Fluxing advection solver not implemented in dimension 2");
  }
  case 3: {
    throw NotImplementedError("Fluxing advection solver not implemented in dimension 3");
  }
  default: {
    throw UnexpectedError("Invalid mesh dimension");
  }
  }
}