#include <language/algorithms/AcousticSolverAlgorithm.hpp>

#include <language/utils/InterpolateItemValue.hpp>
#include <output/VTKWriter.hpp>
#include <scheme/AcousticSolver.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>

template <size_t Dimension>
// Sets up and runs an acoustic-solver simulation on a Dimension-dimensional mesh.
//
//  - solver_type        : forwarded to AcousticSolver.
//  - i_mesh             : mesh to run on; it must be a Mesh<Connectivity<Dimension>>,
//                         otherwise a NormalError is thrown.
//  - bc_descriptor_list : boundary conditions. Accepted: symmetry, and Dirichlet named
//                         "velocity" or "pressure". Any other descriptor throws NormalError.
//  - rho_id, u_id, p_id : functions interpolated at cell centers (mesh_data.xj()) to set the
//                         initial density, velocity and pressure.
//
// Run parameters are hard-coded here: gammaj is filled with 1.4, the final time is 0.2,
// acoustic_dt is scaled by 0.4, and VTK output uses the prefix "mesh_<Dimension>" with a
// period of 0.01. The final state is always written.
AcousticSolverAlgorithm<Dimension>::AcousticSolverAlgorithm(
  const AcousticSolverType& solver_type,
  std::shared_ptr<const IMesh> i_mesh,
  const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
  const FunctionSymbolId& rho_id,
  const FunctionSymbolId& u_id,
  const FunctionSymbolId& p_id)
{
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;
  using MeshDataType     = MeshData<Dimension>;
  using UnknownsType     = FiniteVolumesEulerUnknowns<MeshType>;

  std::shared_ptr<const MeshType> mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);
  if (not mesh) {
    // dynamic_pointer_cast yields nullptr when i_mesh is null or is not a mesh of this
    // Dimension. Everything below dereferences mesh, so report it instead of crashing.
    throw NormalError("acoustic solver expects a mesh of dimension " + std::to_string(Dimension));
  }

  typename AcousticSolver<MeshType>::BoundaryConditionList bc_list;
  {
    // Boundaries are looked up in face reference lists, or in node reference lists in 1D.
    constexpr ItemType FaceType = [] {
      if constexpr (Dimension > 1) {
        return ItemType::face;
      } else {
        return ItemType::node;
      }
    }();

    // The mesh is not modified while boundary conditions are built, so hoist these.
    const auto& connectivity              = mesh->connectivity();
    const size_t number_of_ref_face_lists = connectivity.template numberOfRefItemList<FaceType>();

    // NOTE(review): a descriptor that matches no reference list adds no boundary condition
    // and is not reported.
    for (const auto& bc_descriptor : bc_descriptor_list) {
      bool is_valid_boundary_condition = true;

      switch (bc_descriptor->type()) {
      case IBoundaryConditionDescriptor::Type::symmetry: {
        using SymmetryBoundaryCondition = typename AcousticSolver<MeshType>::SymmetryBoundaryCondition;

        const SymmetryBoundaryConditionDescriptor& sym_bc_descriptor =
          dynamic_cast<const SymmetryBoundaryConditionDescriptor&>(*bc_descriptor);
        for (size_t i_ref_face_list = 0; i_ref_face_list < number_of_ref_face_lists; ++i_ref_face_list) {
          const auto& ref_face_list = connectivity.template refItemList<FaceType>(i_ref_face_list);
          const RefId& ref          = ref_face_list.refId();
          if (ref == sym_bc_descriptor.boundaryDescriptor()) {
            bc_list.push_back(
              SymmetryBoundaryCondition{MeshFlatNodeBoundary<MeshType::Dimension>(mesh, ref_face_list)});
          }
        }
        break;
      }

      case IBoundaryConditionDescriptor::Type::dirichlet: {
        const DirichletBoundaryConditionDescriptor& dirichlet_bc_descriptor =
          dynamic_cast<const DirichletBoundaryConditionDescriptor&>(*bc_descriptor);
        if (dirichlet_bc_descriptor.name() == "velocity") {
          using VelocityBoundaryCondition = typename AcousticSolver<MeshType>::VelocityBoundaryCondition;

          const FunctionSymbolId velocity_id = dirichlet_bc_descriptor.rhsSymbolId();

          for (size_t i_ref_face_list = 0; i_ref_face_list < number_of_ref_face_lists; ++i_ref_face_list) {
            const auto& ref_face_list = connectivity.template refItemList<FaceType>(i_ref_face_list);
            const RefId& ref          = ref_face_list.refId();
            if (ref == dirichlet_bc_descriptor.boundaryDescriptor()) {
              MeshNodeBoundary<Dimension> mesh_node_boundary{mesh, ref_face_list};

              const auto& node_list = mesh_node_boundary.nodeList();

              // Velocity values are imposed on boundary nodes, evaluated at node coordinates.
              Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
                TinyVector<Dimension>)>::template interpolate<ItemType::node>(velocity_id, mesh->xr(), node_list);

              bc_list.push_back(VelocityBoundaryCondition{node_list, value_list});
            }
          }
        } else if (dirichlet_bc_descriptor.name() == "pressure") {
          using PressureBoundaryCondition = typename AcousticSolver<MeshType>::PressureBoundaryCondition;

          const FunctionSymbolId pressure_id = dirichlet_bc_descriptor.rhsSymbolId();

          for (size_t i_ref_face_list = 0; i_ref_face_list < number_of_ref_face_lists; ++i_ref_face_list) {
            const auto& ref_face_list = connectivity.template refItemList<FaceType>(i_ref_face_list);
            const RefId& ref          = ref_face_list.refId();
            if (ref == dirichlet_bc_descriptor.boundaryDescriptor()) {
              const auto& face_list = ref_face_list.list();

              // Pressure values are imposed on boundary faces. In 1D those items are nodes,
              // so node coordinates are used; otherwise mesh_data.xl() is used.
              Array<const double> face_values = [&] {
                if constexpr (Dimension == 1) {
                  return InterpolateItemValue<double(
                    TinyVector<Dimension>)>::template interpolate<FaceType>(pressure_id, mesh->xr(), face_list);
                } else {
                  MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

                  return InterpolateItemValue<double(
                    TinyVector<Dimension>)>::template interpolate<FaceType>(pressure_id, mesh_data.xl(), face_list);
                }
              }();

              bc_list.push_back(PressureBoundaryCondition{face_list, face_values});
            }
          }
        } else {
          is_valid_boundary_condition = false;
        }
        break;
      }
      default: {
        is_valid_boundary_condition = false;
      }
      }
      if (not is_valid_boundary_condition) {
        std::ostringstream error_msg;
        error_msg << *bc_descriptor << " is an invalid boundary condition for acoustic solver";
        throw NormalError(error_msg.str());
      }
    }
  }

  UnknownsType unknowns(*mesh);

  {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

    // Initial state: user functions evaluated at cell centers.
    unknowns.rhoj() =
      InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(rho_id, mesh_data.xj());

    unknowns.pj() =
      InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(p_id, mesh_data.xj());

    unknowns.uj() =
      InterpolateItemValue<TinyVector<Dimension>(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(u_id,
                                                                                                               mesh_data
                                                                                                                 .xj());
  }
  // NOTE(review): hard-coded value, presumably the perfect-gas ratio of specific heats.
  unknowns.gammaj().fill(1.4);

  AcousticSolver acoustic_solver(mesh, bc_list, solver_type);

  const double tmax = 0.2;
  double t          = 0;

  int itermax   = std::numeric_limits<int>::max();
  int iteration = 0;

  CellValue<double>& rhoj              = unknowns.rhoj();
  CellValue<double>& ej                = unknowns.ej();
  CellValue<double>& pj                = unknowns.pj();
  CellValue<double>& gammaj            = unknowns.gammaj();
  CellValue<double>& cj                = unknowns.cj();
  CellValue<TinyVector<Dimension>>& uj = unknowns.uj();
  CellValue<double>& Ej                = unknowns.Ej();
  CellValue<double>& mj                = unknowns.mj();
  CellValue<double>& inv_mj            = unknowns.invMj();

  // Complete the initial state: ej and cj from rhoj and pj.
  BlockPerfectGas block_eos(rhoj, ej, pj, gammaj, cj);
  block_eos.updateEandCFromRhoP();

  {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

    const CellValue<const double> Vj = mesh_data.Vj();

    // Total energy, cell mass and its inverse.
    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { Ej[j] = ej[j] + 0.5 * (uj[j], uj[j]); });

    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { mj[j] = rhoj[j] * Vj[j]; });

    parallel_for(
      mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { inv_mj[j] = 1. / mj[j]; });
  }

  VTKWriter vtk_writer("mesh_" + std::to_string(Dimension), 0.01);

  // The progress log below is printed in scientific notation. Restore the caller's
  // std::cout flags when this constructor exits, whether it returns or throws.
  struct CoutFlagsRestorer
  {
    std::ios_base::fmtflags flags = std::cout.flags();
    ~CoutFlagsRestorer()
    {
      std::cout.flags(flags);
    }
  } cout_flags_restorer;
  // The floatfield mask clears any previously set 'fixed' flag (fixed|scientific means hexfloat).
  std::cout.setf(std::ios_base::scientific, std::ios_base::floatfield);

  while ((t < tmax) and (iteration < itermax)) {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

    vtk_writer.write(mesh,
                     {NamedItemValue{"density", rhoj}, NamedItemValue{"velocity", unknowns.uj()},
                      NamedItemValue{"coords", mesh->xr()}, NamedItemValue{"xj", mesh_data.xj()},
                      NamedItemValue{"cell_owner", mesh->connectivity().cellOwner()},
                      NamedItemValue{"node_owner", mesh->connectivity().nodeOwner()}},
                     t);

    double dt = 0.4 * acoustic_solver.acoustic_dt(mesh_data.Vj(), cj);
    if (not(dt > 0)) {
      // With dt <= 0, t would never reach tmax and the loop would spin for up to INT_MAX
      // iterations. A NaN dt would silently end the run with a meaningless time.
      std::ostringstream error_msg;
      error_msg << "acoustic solver computed an invalid time step dt=" << dt << " at iteration " << iteration;
      throw NormalError(error_msg.str());
    }
    if (t + dt > tmax) {
      dt = tmax - t;
    }

    std::cout << "iteration " << rang::fg::cyan << std::setw(4) << iteration << rang::style::reset
              << " time=" << rang::fg::green << t << rang::style::reset << " dt=" << rang::fgB::blue << dt
              << rang::style::reset << '\n';

    // computeNextStep returns the moved mesh; later iterations use it.
    mesh = acoustic_solver.computeNextStep(dt, unknowns);

    block_eos.updatePandCFromRhoE();

    t += dt;
    ++iteration;
  }
  std::cout << rang::style::bold << "Final time=" << rang::fgB::green << t << rang::style::reset << " reached after "
            << rang::fgB::cyan << iteration << rang::style::reset << rang::style::bold << " iterations"
            << rang::style::reset << '\n';
  {
    MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*mesh);

    vtk_writer.write(mesh,
                     {NamedItemValue{"density", rhoj}, NamedItemValue{"velocity", unknowns.uj()},
                      NamedItemValue{"coords", mesh->xr()}, NamedItemValue{"xj", mesh_data.xj()},
                      NamedItemValue{"cell_owner", mesh->connectivity().cellOwner()},
                      NamedItemValue{"node_owner", mesh->connectivity().nodeOwner()}},
                     t, true);   // forces last output
  }
}

// Explicit instantiation of the constructor for one-dimensional meshes.
template AcousticSolverAlgorithm<1>::AcousticSolverAlgorithm(const AcousticSolverType&,
                                                             std::shared_ptr<const IMesh>,
                                                             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
                                                             const FunctionSymbolId&,
                                                             const FunctionSymbolId&,
                                                             const FunctionSymbolId&);

// Explicit instantiation of the constructor for two-dimensional meshes.
template AcousticSolverAlgorithm<2>::AcousticSolverAlgorithm(const AcousticSolverType&,
                                                             std::shared_ptr<const IMesh>,
                                                             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
                                                             const FunctionSymbolId&,
                                                             const FunctionSymbolId&,
                                                             const FunctionSymbolId&);

// Explicit instantiation of the constructor for three-dimensional meshes.
template AcousticSolverAlgorithm<3>::AcousticSolverAlgorithm(const AcousticSolverType&,
                                                             std::shared_ptr<const IMesh>,
                                                             const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&,
                                                             const FunctionSymbolId&,
                                                             const FunctionSymbolId&,
                                                             const FunctionSymbolId&);
