#include <language/modules/SchemeModule.hpp>

#include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/TypeDescriptor.hpp>
#include <mesh/Mesh.hpp>
#include <scheme/AcousticSolver.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryDescriptor.hpp>
#include <scheme/NamedBoundaryDescriptor.hpp>
#include <scheme/NumberedBoundaryDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>

#include <memory>

/////////// TEMPORARY

#include <output/VTKWriter.hpp>

template <size_t Dimension>
struct GlaceScheme
{
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;
  using MeshDataType     = MeshData<MeshType>;
  using UnknownsType     = FiniteVolumesEulerUnknowns<MeshDataType>;

  const MeshType& m_mesh;

  GlaceScheme(const IMesh& mesh, std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>> bc_descriptor_list)
    : m_mesh{dynamic_cast<const MeshType&>(mesh)}
  {
    MeshDataType mesh_data(m_mesh);

    std::cout << "number of bc descr = " << bc_descriptor_list.size() << '\n';

    std::vector<BoundaryConditionHandler> bc_list;
    {
      constexpr ItemType FaceType = [] {
        if constexpr (Dimension > 1) {
          return ItemType::face;
        } else {
          return ItemType::node;
        }
      }();

      for (const auto& bc_descriptor : bc_descriptor_list) {
        switch (bc_descriptor->type()) {
        case IBoundaryConditionDescriptor::Type::symmetry: {
          const SymmetryBoundaryConditionDescriptor& sym_bc_descriptor =
            dynamic_cast<const SymmetryBoundaryConditionDescriptor&>(*bc_descriptor);
          for (size_t i_ref_face_list = 0;
               i_ref_face_list < m_mesh.connectivity().template numberOfRefItemList<FaceType>(); ++i_ref_face_list) {
            const auto& ref_face_list = m_mesh.connectivity().template refItemList<FaceType>(i_ref_face_list);
            const RefId& ref          = ref_face_list.refId();
            if (ref == sym_bc_descriptor.boundaryDescriptor()) {
              SymmetryBoundaryCondition<MeshType::Dimension>* sym_bc =
                new SymmetryBoundaryCondition<MeshType::Dimension>(
                  MeshFlatNodeBoundary<MeshType::Dimension>(m_mesh, ref_face_list));
              std::shared_ptr<SymmetryBoundaryCondition<MeshType::Dimension>> bc(sym_bc);
              bc_list.push_back(BoundaryConditionHandler(bc));
            }
          }
          break;
        }
        default: {
          throw UnexpectedError("Unknown BCDescription\n");
        }
        }
      }
    }

    UnknownsType unknowns(mesh_data);
    unknowns.initializeSod();

    AcousticSolver<MeshDataType> acoustic_solver(mesh_data, bc_list);

    const CellValue<const double>& Vj = mesh_data.Vj();

    const double tmax = 0.2;
    double t          = 0;

    int itermax   = std::numeric_limits<int>::max();
    int iteration = 0;

    CellValue<double>& rhoj   = unknowns.rhoj();
    CellValue<double>& ej     = unknowns.ej();
    CellValue<double>& pj     = unknowns.pj();
    CellValue<double>& gammaj = unknowns.gammaj();
    CellValue<double>& cj     = unknowns.cj();

    BlockPerfectGas block_eos(rhoj, ej, pj, gammaj, cj);

    VTKWriter vtk_writer("mesh_" + std::to_string(Dimension), 0.01);

    while ((t < tmax) and (iteration < itermax)) {
      vtk_writer.write(m_mesh,
                       {NamedItemValue{"density", rhoj}, NamedItemValue{"velocity", unknowns.uj()},
                        NamedItemValue{"coords", m_mesh.xr()},
                        NamedItemValue{"cell_owner", m_mesh.connectivity().cellOwner()},
                        NamedItemValue{"node_owner", m_mesh.connectivity().nodeOwner()}},
                       t);
      double dt = 0.4 * acoustic_solver.acoustic_dt(Vj, cj);
      if (t + dt > tmax) {
        dt = tmax - t;
      }
      acoustic_solver.computeNextStep(t, dt, unknowns);

      block_eos.updatePandCFromRhoE();

      t += dt;
      ++iteration;
    }
    vtk_writer.write(m_mesh,
                     {NamedItemValue{"density", rhoj}, NamedItemValue{"velocity", unknowns.uj()},
                      NamedItemValue{"coords", m_mesh.xr()},
                      NamedItemValue{"cell_owner", m_mesh.connectivity().cellOwner()},
                      NamedItemValue{"node_owner", m_mesh.connectivity().nodeOwner()}},
                     t, true);   // forces last output
  }
};

SchemeModule::SchemeModule()
{
  // Short local names for the handle types exchanged with the language.
  using BoundaryDescriptorPtr           = std::shared_ptr<const IBoundaryDescriptor>;
  using BoundaryConditionDescriptorPtr  = std::shared_ptr<const IBoundaryConditionDescriptor>;
  using BoundaryConditionDescriptorList = std::vector<BoundaryConditionDescriptorPtr>;
  using MeshPtr                         = std::shared_ptr<const IMesh>;

  // Declare the descriptor handle types to the language.
  this->_addTypeDescriptor(std::make_shared<TypeDescriptor>(ast_node_data_type_from<BoundaryDescriptorPtr>.typeName()));
  this->_addTypeDescriptor(
    std::make_shared<TypeDescriptor>(ast_node_data_type_from<BoundaryConditionDescriptorPtr>.typeName()));

  // boundaryName(name): a boundary designated by its name.
  this->_addBuiltinFunction("boundaryName",
                            std::make_shared<BuiltinFunctionEmbedder<BoundaryDescriptorPtr, std::string>>(
                              std::function<BoundaryDescriptorPtr(std::string)>{
                                [](const std::string& name) -> BoundaryDescriptorPtr {
                                  return std::make_shared<NamedBoundaryDescriptor>(name);
                                }}));

  // boundaryTag(tag): a boundary designated by its integer tag.
  this->_addBuiltinFunction("boundaryTag",
                            std::make_shared<BuiltinFunctionEmbedder<BoundaryDescriptorPtr, int64_t>>(
                              std::function<BoundaryDescriptorPtr(int64_t)>{
                                [](int64_t tag) -> BoundaryDescriptorPtr {
                                  return std::make_shared<NumberedBoundaryDescriptor>(tag);
                                }}));

  // symmetry(boundary): a symmetry condition on the given boundary.
  this->_addBuiltinFunction("symmetry",
                            std::make_shared<
                              BuiltinFunctionEmbedder<BoundaryConditionDescriptorPtr, BoundaryDescriptorPtr>>(
                              std::function<BoundaryConditionDescriptorPtr(BoundaryDescriptorPtr)>{
                                [](BoundaryDescriptorPtr boundary) -> BoundaryConditionDescriptorPtr {
                                  return std::make_shared<SymmetryBoundaryConditionDescriptor>(boundary);
                                }}));

  // glace(mesh, bc_list): runs the GLACE scheme, instantiated for the mesh's
  // dimension; any dimension other than 1, 2 or 3 is rejected.
  this->_addBuiltinFunction("glace",
                            std::make_shared<BuiltinFunctionEmbedder<void, MeshPtr, BoundaryConditionDescriptorList>>(
                              std::function<void(MeshPtr, BoundaryConditionDescriptorList)>{
                                [](MeshPtr mesh_ptr, const BoundaryConditionDescriptorList& bc_list) -> void {
                                  const IMesh& mesh = *mesh_ptr;
                                  switch (mesh.dimension()) {
                                  case 1:
                                    GlaceScheme<1>{mesh, bc_list};
                                    return;
                                  case 2:
                                    GlaceScheme<2>{mesh, bc_list};
                                    return;
                                  case 3:
                                    GlaceScheme<3>{mesh, bc_list};
                                    return;
                                  default:
                                    throw UnexpectedError("invalid mesh dimension");
                                  }
                                }}));
}
