#include <PastisUtils.hpp>
#include <PastisOStream.hpp>

#include <rang.hpp>

#include <Connectivity.hpp>

#include <Mesh.hpp>
#include <BoundaryCondition.hpp>
#include <AcousticSolver.hpp>

#include <VTKWriter.hpp>

#include <Timer.hpp>

#include <TinyVector.hpp>
#include <TinyMatrix.hpp>

#include <BoundaryConditionDescriptor.hpp>

#include <MeshNodeBoundary.hpp>

#include <GmshReader.hpp>
#include <PastisParser.hpp>

#include <SynchronizerManager.hpp>

#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <limits>
#include <map>
#include <memory>
#include <regex>
#include <string>
#include <vector>

int main(int argc, char *argv[])
{
  std::string filename = initialize(argc, argv);

   std::regex gmsh_regex("(.*).msh");
   if (not std::regex_match(filename, gmsh_regex))  {
      parser(filename);
      return 0;
   }

  std::map<std::string, double> method_cost_map;

  SynchronizerManager::create();

  if (filename != "") {
    pout() << "Reading (gmsh) " << rang::style::underline << filename << rang::style::reset << " ...\n";
    Timer gmsh_timer;
    gmsh_timer.reset();
    GmshReader gmsh_reader(filename);
    method_cost_map["Mesh building"] = gmsh_timer.seconds();

    std::shared_ptr<IMesh> p_mesh = gmsh_reader.mesh();

    switch (p_mesh->dimension()) {
      case 1: {
        std::vector<std::string> sym_boundary_name_list = {"XMIN", "XMAX"};
        std::vector<std::shared_ptr<BoundaryConditionDescriptor>> bc_descriptor_list;
        for (const auto& sym_boundary_name : sym_boundary_name_list){
          std::shared_ptr<BoundaryDescriptor> boudary_descriptor
              = std::shared_ptr<BoundaryDescriptor>(new NamedBoundaryDescriptor(sym_boundary_name));
          SymmetryBoundaryConditionDescriptor* sym_bc_descriptor
              = new SymmetryBoundaryConditionDescriptor(boudary_descriptor);

          bc_descriptor_list.push_back(std::shared_ptr<BoundaryConditionDescriptor>(sym_bc_descriptor));
        }

        using ConnectivityType = Connectivity1D;
        using MeshType = Mesh<ConnectivityType>;
        using MeshDataType = MeshData<MeshType>;
        using UnknownsType = FiniteVolumesEulerUnknowns<MeshDataType>;

        const MeshType& mesh = dynamic_cast<const MeshType&>(*gmsh_reader.mesh());

        Timer timer;
        timer.reset();
        MeshDataType mesh_data(mesh);

        std::vector<BoundaryConditionHandler> bc_list;
        {
          for (const auto& bc_descriptor : bc_descriptor_list) {
            switch (bc_descriptor->type()) {
              case BoundaryConditionDescriptor::Type::symmetry: {
                const SymmetryBoundaryConditionDescriptor& sym_bc_descriptor
                    = dynamic_cast<const SymmetryBoundaryConditionDescriptor&>(*bc_descriptor);
                for (size_t i_ref_node_list=0; i_ref_node_list<mesh.connectivity().numberOfRefNodeList();
                     ++i_ref_node_list) {
                  const RefNodeList& ref_node_list = mesh.connectivity().refNodeList(i_ref_node_list);
                  const RefId& ref = ref_node_list.refId();
                  if (ref == sym_bc_descriptor.boundaryDescriptor()) {
                    SymmetryBoundaryCondition<MeshType::Dimension>* sym_bc
                        = new SymmetryBoundaryCondition<MeshType::Dimension>(MeshFlatNodeBoundary<MeshType::Dimension>(mesh, ref_node_list));
                    std::shared_ptr<SymmetryBoundaryCondition<MeshType::Dimension>> bc(sym_bc);
                    bc_list.push_back(BoundaryConditionHandler(bc));
                  }
                }
                break;
              }
              default: {
                perr() << "Unknown BCDescription\n";
                std::exit(1);
              }
            }
          }
        }

        UnknownsType unknowns(mesh_data);

        unknowns.initializeSod();

        AcousticSolver<MeshDataType> acoustic_solver(mesh_data, bc_list);

        using Rd = TinyVector<MeshType::Dimension>;

        const CellValue<const double>& Vj = mesh_data.Vj();

        const double tmax=0.2;
        double t=0;

        int itermax=std::numeric_limits<int>::max();
        int iteration=0;

        CellValue<double>& rhoj = unknowns.rhoj();
        CellValue<double>& ej = unknowns.ej();
        CellValue<double>& pj = unknowns.pj();
        CellValue<double>& gammaj = unknowns.gammaj();
        CellValue<double>& cj = unknowns.cj();

        BlockPerfectGas block_eos(rhoj, ej, pj, gammaj, cj);

        VTKWriter vtk_writer("mesh", 0.01);

        while((t<tmax) and (iteration<itermax)) {
          vtk_writer.write(mesh, {NamedItemValue{"density", rhoj},
                                  NamedItemValue{"velocity", unknowns.uj()},
                                  NamedItemValue{"coords", mesh.xr()},
                                  NamedItemValue{"cell_owner", mesh.connectivity().cellOwner()},
                                  NamedItemValue{"node_owner", mesh.connectivity().nodeOwner()}},t);
          double dt = 0.4*acoustic_solver.acoustic_dt(Vj, cj);
          if (t+dt>tmax) {
            dt=tmax-t;
          }
          acoustic_solver.computeNextStep(t,dt, unknowns);

          block_eos.updatePandCFromRhoE();

          t += dt;
          ++iteration;
        }
        vtk_writer.write(mesh, {NamedItemValue{"density", rhoj},
                                NamedItemValue{"velocity", unknowns.uj()},
                                NamedItemValue{"coords", mesh.xr()},
                                NamedItemValue{"cell_owner", mesh.connectivity().cellOwner()},
                                NamedItemValue{"node_owner", mesh.connectivity().nodeOwner()}}, t, true); // forces last output

        pout() << "* " << rang::style::underline << "Final time" << rang::style::reset
               << ":  " << rang::fgB::green << t << rang::fg::reset << " (" << iteration << " iterations)\n";

        method_cost_map["AcousticSolverWithMesh"] = timer.seconds();

        { // gnuplot output for density
          const CellValue<const Rd>& xj = mesh_data.xj();
          const CellValue<const double>& rhoj = unknowns.rhoj();
          std::ofstream fout("rho");
          for (CellId j=0; j<mesh.numberOfCells(); ++j) {
            fout << xj[j][0] << ' ' << rhoj[j] << '\n';
          }
        }

        break;
      }
      case 2: {
        // test case boundary condition description
        std::vector<std::string> sym_boundary_name_list = {"XMIN", "XMAX", "YMIN", "YMAX"};
        std::vector<std::shared_ptr<BoundaryConditionDescriptor>> bc_descriptor_list;
        for (const auto& sym_boundary_name : sym_boundary_name_list){
          std::shared_ptr<BoundaryDescriptor> boudary_descriptor
              = std::shared_ptr<BoundaryDescriptor>(new NamedBoundaryDescriptor(sym_boundary_name));
          SymmetryBoundaryConditionDescriptor* sym_bc_descriptor
              = new SymmetryBoundaryConditionDescriptor(boudary_descriptor);

          bc_descriptor_list.push_back(std::shared_ptr<BoundaryConditionDescriptor>(sym_bc_descriptor));
        }

        using ConnectivityType = Connectivity2D;
        using MeshType = Mesh<ConnectivityType>;
        using MeshDataType = MeshData<MeshType>;
        using UnknownsType = FiniteVolumesEulerUnknowns<MeshDataType>;

        const MeshType& mesh = dynamic_cast<const MeshType&>(*gmsh_reader.mesh());

        Timer timer;
        timer.reset();
        MeshDataType mesh_data(mesh);

        std::vector<BoundaryConditionHandler> bc_list;
        {
          for (const auto& bc_descriptor : bc_descriptor_list) {
            switch (bc_descriptor->type()) {
              case BoundaryConditionDescriptor::Type::symmetry: {
                const SymmetryBoundaryConditionDescriptor& sym_bc_descriptor
                    = dynamic_cast<const SymmetryBoundaryConditionDescriptor&>(*bc_descriptor);
                for (size_t i_ref_face_list=0; i_ref_face_list<mesh.connectivity().numberOfRefFaceList();
                     ++i_ref_face_list) {
                  const RefFaceList& ref_face_list = mesh.connectivity().refFaceList(i_ref_face_list);
                  const RefId& ref = ref_face_list.refId();
                  if (ref == sym_bc_descriptor.boundaryDescriptor()) {
                    SymmetryBoundaryCondition<MeshType::Dimension>* sym_bc
                        = new SymmetryBoundaryCondition<MeshType::Dimension>(MeshFlatNodeBoundary<MeshType::Dimension>(mesh, ref_face_list));
                    std::shared_ptr<SymmetryBoundaryCondition<MeshType::Dimension>> bc(sym_bc);
                    bc_list.push_back(BoundaryConditionHandler(bc));
                  }
                }
                break;
              }
              default: {
                perr() << "Unknown BCDescription\n";
                std::exit(1);
              }
            }
          }
        }

        UnknownsType unknowns(mesh_data);

        unknowns.initializeSod();

        AcousticSolver<MeshDataType> acoustic_solver(mesh_data, bc_list);

        const CellValue<const double>& Vj = mesh_data.Vj();

        const double tmax=0.2;
        double t=0;

        int itermax=std::numeric_limits<int>::max();
        int iteration=0;

        CellValue<double>& rhoj = unknowns.rhoj();
        CellValue<double>& ej = unknowns.ej();
        CellValue<double>& pj = unknowns.pj();
        CellValue<double>& gammaj = unknowns.gammaj();
        CellValue<double>& cj = unknowns.cj();

        BlockPerfectGas block_eos(rhoj, ej, pj, gammaj, cj);

        VTKWriter vtk_writer("mesh", 0.01);

        while((t<tmax) and (iteration<itermax)) {
          vtk_writer.write(mesh, {NamedItemValue{"density", rhoj},
                                  NamedItemValue{"velocity", unknowns.uj()},
                                  NamedItemValue{"coords", mesh.xr()},
                                  NamedItemValue{"cell_owner", mesh.connectivity().cellOwner()},
                                  NamedItemValue{"node_owner", mesh.connectivity().nodeOwner()}}, t);
          double dt = 0.4*acoustic_solver.acoustic_dt(Vj, cj);
          if (t+dt>tmax) {
            dt=tmax-t;
          }
          acoustic_solver.computeNextStep(t,dt, unknowns);

          block_eos.updatePandCFromRhoE();

          t += dt;
          ++iteration;
        }
        vtk_writer.write(mesh, {NamedItemValue{"density", rhoj},
                                NamedItemValue{"velocity", unknowns.uj()},
                                NamedItemValue{"coords", mesh.xr()},
                                NamedItemValue{"cell_owner", mesh.connectivity().cellOwner()},
                                NamedItemValue{"node_owner", mesh.connectivity().nodeOwner()}}, t, true); // forces last output

        pout() << "* " << rang::style::underline << "Final time" << rang::style::reset
               << ":  " << rang::fgB::green << t << rang::fg::reset << " (" << iteration << " iterations)\n";

        method_cost_map["AcousticSolverWithMesh"] = timer.seconds();
        break;
      }
      case 3: {
        std::vector<std::string> sym_boundary_name_list = {"XMIN", "XMAX", "YMIN", "YMAX", "ZMIN", "ZMAX"};
        std::vector<std::shared_ptr<BoundaryConditionDescriptor>> bc_descriptor_list;
        for (const auto& sym_boundary_name : sym_boundary_name_list){
          std::shared_ptr<BoundaryDescriptor> boudary_descriptor
              = std::shared_ptr<BoundaryDescriptor>(new NamedBoundaryDescriptor(sym_boundary_name));
          SymmetryBoundaryConditionDescriptor* sym_bc_descriptor
              = new SymmetryBoundaryConditionDescriptor(boudary_descriptor);

          bc_descriptor_list.push_back(std::shared_ptr<BoundaryConditionDescriptor>(sym_bc_descriptor));
        }

        using ConnectivityType = Connectivity3D;
        using MeshType = Mesh<ConnectivityType>;
        using MeshDataType = MeshData<MeshType>;
        using UnknownsType = FiniteVolumesEulerUnknowns<MeshDataType>;

        const MeshType& mesh = dynamic_cast<const MeshType&>(*gmsh_reader.mesh());

        Timer timer;
        timer.reset();
        MeshDataType mesh_data(mesh);

        std::vector<BoundaryConditionHandler> bc_list;
        {
          for (const auto& bc_descriptor : bc_descriptor_list) {
            switch (bc_descriptor->type()) {
              case BoundaryConditionDescriptor::Type::symmetry: {
                const SymmetryBoundaryConditionDescriptor& sym_bc_descriptor
                    = dynamic_cast<const SymmetryBoundaryConditionDescriptor&>(*bc_descriptor);
                for (size_t i_ref_face_list=0; i_ref_face_list<mesh.connectivity().numberOfRefFaceList();
                     ++i_ref_face_list) {
                  const RefFaceList& ref_face_list = mesh.connectivity().refFaceList(i_ref_face_list);
                  const RefId& ref = ref_face_list.refId();
                  if (ref == sym_bc_descriptor.boundaryDescriptor()) {
                    SymmetryBoundaryCondition<MeshType::Dimension>* sym_bc
                        = new SymmetryBoundaryCondition<MeshType::Dimension>(MeshFlatNodeBoundary<MeshType::Dimension>(mesh, ref_face_list));
                    std::shared_ptr<SymmetryBoundaryCondition<MeshType::Dimension>> bc(sym_bc);
                    bc_list.push_back(BoundaryConditionHandler(bc));
                  }
                }
                break;
              }
              default: {
                perr() << "Unknown BCDescription\n";
                std::exit(1);
              }
            }
          }
        }

        UnknownsType unknowns(mesh_data);

        unknowns.initializeSod();

        AcousticSolver<MeshDataType> acoustic_solver(mesh_data, bc_list);

        const CellValue<const double>& Vj = mesh_data.Vj();

        const double tmax=0.2;
        double t=0;

        int itermax=std::numeric_limits<int>::max();
        int iteration=0;

        CellValue<double>& rhoj = unknowns.rhoj();
        CellValue<double>& ej = unknowns.ej();
        CellValue<double>& pj = unknowns.pj();
        CellValue<double>& gammaj = unknowns.gammaj();
        CellValue<double>& cj = unknowns.cj();

        BlockPerfectGas block_eos(rhoj, ej, pj, gammaj, cj);

        VTKWriter vtk_writer("mesh", 0.01);

        while((t<tmax) and (iteration<itermax)) {
          vtk_writer.write(mesh, {NamedItemValue{"density", rhoj},
                                  NamedItemValue{"velocity", unknowns.uj()},
                                  NamedItemValue{"coords", mesh.xr()},
                                  NamedItemValue{"cell_owner", mesh.connectivity().cellOwner()},
                                  NamedItemValue{"node_owner", mesh.connectivity().nodeOwner()}}, t);
          double dt = 0.4*acoustic_solver.acoustic_dt(Vj, cj);
          if (t+dt>tmax) {
            dt=tmax-t;
          }
          acoustic_solver.computeNextStep(t,dt, unknowns);
          block_eos.updatePandCFromRhoE();

          t += dt;
          ++iteration;
        }
        vtk_writer.write(mesh, {NamedItemValue{"density", rhoj},
                                NamedItemValue{"velocity", unknowns.uj()},
                                NamedItemValue{"coords", mesh.xr()},
                                NamedItemValue{"cell_owner", mesh.connectivity().cellOwner()},
                                NamedItemValue{"node_owner", mesh.connectivity().nodeOwner()}}, t, true); // forces last output

        pout() << "* " << rang::style::underline << "Final time" << rang::style::reset
               << ":  " << rang::fgB::green << t << rang::fg::reset << " (" << iteration << " iterations)\n";

        method_cost_map["AcousticSolverWithMesh"] = timer.seconds();
        break;
      }
    }

    pout() << "* "  << rang::fgB::red << "Could not be uglier!" << rang::fg::reset << " (" << __FILE__ << ':' << __LINE__ << ")\n";

  } else {
    perr() << "Connectivity1D defined by number of nodes no more implemented\n";
    std::exit(0);
  }

  SynchronizerManager::destroy();

  finalize();

  std::string::size_type size=0;
  for (const auto& method_cost : method_cost_map) {
    size = std::max(size, method_cost.first.size());
  }

  for (const auto& method_cost : method_cost_map) {
    pout() << "* ["
           << rang::fgB::cyan
           << std::setw(size) << std::left
           << method_cost.first
           << rang::fg::reset
           << "] Execution time: "
           << rang::style::bold
           << method_cost.second
           << rang::style::reset << '\n';
  }

  return 0;
}