#include <language/modules/SchemeModule.hpp>

#include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/TypeDescriptor.hpp>
#include <mesh/Mesh.hpp>
#include <scheme/AcousticSolver.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryDescriptor.hpp>
#include <scheme/NamedBoundaryDescriptor.hpp>
#include <scheme/NumberedBoundaryDescriptor.hpp>
#include <scheme/PressureBoundaryConditionDescriptor.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>
#include <scheme/VelocityBoundaryConditionDescriptor.hpp>

#include <iomanip>
#include <ios>
#include <iostream>
#include <limits>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <vector>

/////////// TEMPORARY

#include <language/utils/PugsFunctionAdapter.hpp>
#include <output/VTKWriter.hpp>

/// Evaluates a language function (designated by a FunctionSymbolId) at the positions of
/// mesh items. Only the specialization on a function type OutputType(InputType) is defined.
template <typename T>
class InterpolateItemValue;

/// @tparam OutputType type of the values returned by the function (e.g. double or TinyVector)
/// @tparam InputType  type of the function argument, i.e. of the item positions
template <typename OutputType, typename InputType>
class InterpolateItemValue<OutputType(InputType)> : public PugsFunctionAdapter<OutputType(InputType)>
{
  using Adapter = PugsFunctionAdapter<OutputType(InputType)>;

 public:
  /// @brief Evaluates the function at every item of kind @p item_type.
  ///
  /// @param function_symbol_id function to evaluate
  /// @param position           function argument for each item (e.g. node or cell coordinates)
  /// @return one value per item, built on the same connectivity as @p position
  template <ItemType item_type>
  static inline ItemValue<OutputType, item_type>
  interpolate(const FunctionSymbolId& function_symbol_id, const ItemValue<const InputType, item_type>& position)
  {
    auto& expression    = Adapter::getFunctionExpression(function_symbol_id);
    auto convert_result = Adapter::getResultConverter(expression.m_data_type);

    // Each running iteration borrows one evaluation context through a unique token, so that
    // concurrent evaluations never share a context.
    // NOTE(review): assumes getContextList provides one context per possible token value — confirm.
    Array<ExecutionPolicy> context_list = Adapter::getContextList(expression);

    using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space;
    Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens;
    const IConnectivity& connectivity = *position.connectivity_ptr();

    ItemValue<OutputType, item_type> value(connectivity);
    using ItemId = ItemIdT<item_type>;

    parallel_for(connectivity.template numberOf<item_type>(), [=, &expression, &tokens](ItemId i) {
      const int32_t t = tokens.acquire();

      auto& execution_policy = context_list[t];

      // Load the argument into the borrowed context, evaluate, convert to OutputType.
      Adapter::convertArgs(execution_policy.currentContext(), position[i]);
      auto result = expression.execute(execution_policy);
      value[i]    = convert_result(std::move(result));

      tokens.release(t);
    });

    return value;
  }

  /// @brief Evaluates the function at a subset of items of kind @p item_type.
  ///
  /// @param function_symbol_id function to evaluate
  /// @param position           function argument for each item (e.g. node or cell coordinates)
  /// @param list_of_items      items at which the function is evaluated
  /// @return an array whose entry i is the value at list_of_items[i]
  template <ItemType item_type>
  static inline Array<OutputType>
  interpolate(const FunctionSymbolId& function_symbol_id,
              const ItemValue<const InputType, item_type>& position,
              const Array<const ItemIdT<item_type>>& list_of_items)
  {
    auto& expression    = Adapter::getFunctionExpression(function_symbol_id);
    auto convert_result = Adapter::getResultConverter(expression.m_data_type);

    // Same per-token context borrowing as the whole-item-set overload above.
    Array<ExecutionPolicy> context_list = Adapter::getContextList(expression);

    using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space;
    Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global> tokens;

    Array<OutputType> value{list_of_items.size()};
    using ItemId = ItemIdT<item_type>;

    parallel_for(list_of_items.size(), [=, &expression, &tokens](size_t i_item) {
      ItemId item_id  = list_of_items[i_item];
      const int32_t t = tokens.acquire();

      auto& execution_policy = context_list[t];

      Adapter::convertArgs(execution_policy.currentContext(), position[item_id]);
      auto result   = expression.execute(execution_policy);
      value[i_item] = convert_result(std::move(result));

      tokens.release(t);
    });

    return value;
  }
};

template <size_t Dimension>
struct GlaceScheme
{
  using ConnectivityType = Connectivity<Dimension>;
  using MeshType         = Mesh<ConnectivityType>;
  using MeshDataType     = MeshData<Dimension>;
  using UnknownsType     = FiniteVolumesEulerUnknowns<MeshType>;

  std::shared_ptr<const MeshType> m_mesh;

  GlaceScheme(std::shared_ptr<const IMesh> i_mesh,
              const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list,
              const FunctionSymbolId& rho_id,
              const FunctionSymbolId& u_id,
              const FunctionSymbolId& p_id)
    : m_mesh{std::dynamic_pointer_cast<const MeshType>(i_mesh)}
  {
    std::cout << "number of bc descr = " << bc_descriptor_list.size() << '\n';

    std::vector<BoundaryConditionHandler> bc_list;
    {
      constexpr ItemType FaceType = [] {
        if constexpr (Dimension > 1) {
          return ItemType::face;
        } else {
          return ItemType::node;
        }
      }();

      for (const auto& bc_descriptor : bc_descriptor_list) {
        switch (bc_descriptor->type()) {
        case IBoundaryConditionDescriptor::Type::symmetry: {
          const SymmetryBoundaryConditionDescriptor& sym_bc_descriptor =
            dynamic_cast<const SymmetryBoundaryConditionDescriptor&>(*bc_descriptor);
          for (size_t i_ref_face_list = 0;
               i_ref_face_list < m_mesh->connectivity().template numberOfRefItemList<FaceType>(); ++i_ref_face_list) {
            const auto& ref_face_list = m_mesh->connectivity().template refItemList<FaceType>(i_ref_face_list);
            const RefId& ref          = ref_face_list.refId();
            if (ref == sym_bc_descriptor.boundaryDescriptor()) {
              SymmetryBoundaryCondition<MeshType::Dimension>* sym_bc =
                new SymmetryBoundaryCondition<MeshType::Dimension>(
                  MeshFlatNodeBoundary<MeshType::Dimension>(m_mesh, ref_face_list));
              std::shared_ptr<SymmetryBoundaryCondition<MeshType::Dimension>> bc(sym_bc);
              bc_list.push_back(BoundaryConditionHandler(bc));
            }
          }
          break;
        }
        case IBoundaryConditionDescriptor::Type::velocity: {
          const VelocityBoundaryConditionDescriptor& velocity_bc_descriptor =
            dynamic_cast<const VelocityBoundaryConditionDescriptor&>(*bc_descriptor);
          for (size_t i_ref_face_list = 0;
               i_ref_face_list < m_mesh->connectivity().template numberOfRefItemList<FaceType>(); ++i_ref_face_list) {
            const auto& ref_face_list = m_mesh->connectivity().template refItemList<FaceType>(i_ref_face_list);
            const RefId& ref          = ref_face_list.refId();
            if (ref == velocity_bc_descriptor.boundaryDescriptor()) {
              const FunctionSymbolId velocity_id = velocity_bc_descriptor.functionSymbolId();

              if constexpr (Dimension == 1) {
                const auto& node_list = ref_face_list.list();

                Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
                  TinyVector<Dimension>)>::template interpolate<FaceType>(velocity_id, m_mesh->xr(), node_list);

                std::shared_ptr bc =
                  std::make_shared<VelocityBoundaryCondition<MeshType::Dimension>>(node_list, value_list);
                bc_list.push_back(BoundaryConditionHandler(bc));
              } else {
                const auto& face_list           = ref_face_list.list();
                const auto& face_to_node_matrix = m_mesh->connectivity().faceToNodeMatrix();
                std::set<NodeId> node_set;
                for (size_t i_face = 0; i_face < face_list.size(); ++i_face) {
                  FaceId face_id         = face_list[i_face];
                  const auto& face_nodes = face_to_node_matrix[face_id];
                  for (size_t i_node = 0; i_node < face_nodes.size(); ++i_node) {
                    node_set.insert(face_nodes[i_node]);
                  }
                }

                Array<NodeId> node_list = convert_to_array(node_set);

                Array<const TinyVector<Dimension>> value_list = InterpolateItemValue<TinyVector<Dimension>(
                  TinyVector<Dimension>)>::template interpolate<ItemType::node>(velocity_id, m_mesh->xr(), node_list);

                std::shared_ptr bc =
                  std::make_shared<VelocityBoundaryCondition<MeshType::Dimension>>(node_list, value_list);
                bc_list.push_back(BoundaryConditionHandler(bc));
              }
            }
          }
          break;
        }
        case IBoundaryConditionDescriptor::Type::pressure: {
          const PressureBoundaryConditionDescriptor& pressure_bc_descriptor =
            dynamic_cast<const PressureBoundaryConditionDescriptor&>(*bc_descriptor);
          for (size_t i_ref_face_list = 0;
               i_ref_face_list < m_mesh->connectivity().template numberOfRefItemList<FaceType>(); ++i_ref_face_list) {
            const auto& ref_face_list = m_mesh->connectivity().template refItemList<FaceType>(i_ref_face_list);
            const RefId& ref          = ref_face_list.refId();
            if (ref == pressure_bc_descriptor.boundaryDescriptor()) {
              const auto& face_list = ref_face_list.list();

              const FunctionSymbolId pressure_id = pressure_bc_descriptor.functionSymbolId();

              Array<const double> face_values = [&] {
                if constexpr (Dimension == 1) {
                  return InterpolateItemValue<double(
                    TinyVector<Dimension>)>::template interpolate<FaceType>(pressure_id, m_mesh->xr(), face_list);
                } else {
                  MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

                  return InterpolateItemValue<double(
                    TinyVector<Dimension>)>::template interpolate<FaceType>(pressure_id, mesh_data.xl(), face_list);
                }
              }();

              std::shared_ptr bc =
                std::make_shared<PressureBoundaryCondition<MeshType::Dimension>>(face_list, face_values);
              bc_list.push_back(BoundaryConditionHandler(bc));
            }
          }
          break;
        }
        default: {
          std::ostringstream error_msg;
          error_msg << *bc_descriptor << " is an invalid boundary condition for acoustic solver";
          throw NormalError(error_msg.str());
        }
        }
      }
    }

    UnknownsType unknowns(*m_mesh);

    {
      MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

      unknowns.rhoj() =
        InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(rho_id,
                                                                                                  mesh_data.xj());

      unknowns.pj() =
        InterpolateItemValue<double(TinyVector<Dimension>)>::template interpolate<ItemType::cell>(p_id, mesh_data.xj());

      unknowns.uj() = InterpolateItemValue<TinyVector<Dimension>(
        TinyVector<Dimension>)>::template interpolate<ItemType::cell>(u_id, mesh_data.xj());
    }
    unknowns.gammaj().fill(1.4);

    AcousticSolver acoustic_solver(m_mesh, bc_list);

    const double tmax = 0.2;
    double t          = 0;

    int itermax   = std::numeric_limits<int>::max();
    int iteration = 0;

    CellValue<double>& rhoj              = unknowns.rhoj();
    CellValue<double>& ej                = unknowns.ej();
    CellValue<double>& pj                = unknowns.pj();
    CellValue<double>& gammaj            = unknowns.gammaj();
    CellValue<double>& cj                = unknowns.cj();
    CellValue<TinyVector<Dimension>>& uj = unknowns.uj();
    CellValue<double>& Ej                = unknowns.Ej();
    CellValue<double>& mj                = unknowns.mj();
    CellValue<double>& inv_mj            = unknowns.invMj();

    BlockPerfectGas block_eos(rhoj, ej, pj, gammaj, cj);
    block_eos.updateEandCFromRhoP();

    {
      MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

      const CellValue<const double> Vj = mesh_data.Vj();

      parallel_for(
        m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { Ej[j] = ej[j] + 0.5 * (uj[j], uj[j]); });

      parallel_for(
        m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { mj[j] = rhoj[j] * Vj[j]; });

      parallel_for(
        m_mesh->numberOfCells(), PUGS_LAMBDA(CellId j) { inv_mj[j] = 1. / mj[j]; });
    }

    VTKWriter vtk_writer("mesh_" + std::to_string(Dimension), 0.01);

    while ((t < tmax) and (iteration < itermax)) {
      MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

      vtk_writer.write(m_mesh,
                       {NamedItemValue{"density", rhoj}, NamedItemValue{"velocity", unknowns.uj()},
                        NamedItemValue{"coords", m_mesh->xr()}, NamedItemValue{"xj", mesh_data.xj()},
                        NamedItemValue{"cell_owner", m_mesh->connectivity().cellOwner()},
                        NamedItemValue{"node_owner", m_mesh->connectivity().nodeOwner()}},
                       t);
      double dt = 0.4 * acoustic_solver.acoustic_dt(mesh_data.Vj(), cj);
      if (t + dt > tmax) {
        dt = tmax - t;
      }

      std::cout.setf(std::cout.scientific);
      std::cout << "iteration " << rang::fg::cyan << std::setw(4) << iteration << rang::style::reset
                << " time=" << rang::fg::green << t << rang::style::reset << " dt=" << rang::fgB::blue << dt
                << rang::style::reset << '\n';

      m_mesh = acoustic_solver.computeNextStep(dt, unknowns);

      block_eos.updatePandCFromRhoE();

      t += dt;
      ++iteration;
    }
    std::cout << rang::style::bold << "Final time=" << rang::fgB::green << t << rang::style::reset << " reached after "
              << rang::fgB::cyan << iteration << rang::style::reset << rang::style::bold << " iterations"
              << rang::style::reset << '\n';
    {
      MeshDataType& mesh_data = MeshDataManager::instance().getMeshData(*m_mesh);

      vtk_writer.write(m_mesh,
                       {NamedItemValue{"density", rhoj}, NamedItemValue{"velocity", unknowns.uj()},
                        NamedItemValue{"coords", m_mesh->xr()}, NamedItemValue{"xj", mesh_data.xj()},
                        NamedItemValue{"cell_owner", m_mesh->connectivity().cellOwner()},
                        NamedItemValue{"node_owner", m_mesh->connectivity().nodeOwner()}},
                       t, true);   // forces last output
    }
  }
};

/// Registers the scheme-related language types and builtin functions:
/// boundary designation (boundaryName, boundaryTag), boundary conditions
/// (symmetry, pressure, velocity) and the Glace solver entry point (glace).
SchemeModule::SchemeModule()
{
  // Handle types exchanged between the language and these builtins.
  using BoundaryDescriptorHandle          = std::shared_ptr<const IBoundaryDescriptor>;
  using BoundaryConditionDescriptorHandle = std::shared_ptr<const IBoundaryConditionDescriptor>;
  using BoundaryConditionDescriptorList   = std::vector<BoundaryConditionDescriptorHandle>;

  this->_addTypeDescriptor(ast_node_data_type_from<BoundaryDescriptorHandle>);
  this->_addTypeDescriptor(ast_node_data_type_from<BoundaryConditionDescriptorHandle>);

  // boundaryName(name): boundary designated by its name
  this->_addBuiltinFunction("boundaryName",
                            std::make_shared<BuiltinFunctionEmbedder<BoundaryDescriptorHandle(const std::string&)>>(
                              [](const std::string& name) -> BoundaryDescriptorHandle {
                                return std::make_shared<NamedBoundaryDescriptor>(name);
                              }));

  // boundaryTag(tag): boundary designated by its integer tag
  this->_addBuiltinFunction("boundaryTag",
                            std::make_shared<BuiltinFunctionEmbedder<BoundaryDescriptorHandle(int64_t)>>(
                              [](int64_t tag) -> BoundaryDescriptorHandle {
                                return std::make_shared<NumberedBoundaryDescriptor>(tag);
                              }));

  // symmetry(boundary): symmetry condition on the given boundary
  this->_addBuiltinFunction("symmetry",
                            std::make_shared<
                              BuiltinFunctionEmbedder<BoundaryConditionDescriptorHandle(BoundaryDescriptorHandle)>>(
                              [](BoundaryDescriptorHandle boundary) -> BoundaryConditionDescriptorHandle {
                                return std::make_shared<SymmetryBoundaryConditionDescriptor>(boundary);
                              }));

  // pressure(boundary, p): pressure condition, p being a language function
  this->_addBuiltinFunction("pressure",
                            std::make_shared<BuiltinFunctionEmbedder<BoundaryConditionDescriptorHandle(
                              BoundaryDescriptorHandle, const FunctionSymbolId&)>>(
                              [](BoundaryDescriptorHandle boundary,
                                 const FunctionSymbolId& function_id) -> BoundaryConditionDescriptorHandle {
                                return std::make_shared<PressureBoundaryConditionDescriptor>(boundary, function_id);
                              }));

  // velocity(boundary, u): velocity condition, u being a language function
  this->_addBuiltinFunction("velocity",
                            std::make_shared<BuiltinFunctionEmbedder<BoundaryConditionDescriptorHandle(
                              BoundaryDescriptorHandle, const FunctionSymbolId&)>>(
                              [](BoundaryDescriptorHandle boundary,
                                 const FunctionSymbolId& function_id) -> BoundaryConditionDescriptorHandle {
                                return std::make_shared<VelocityBoundaryConditionDescriptor>(boundary, function_id);
                              }));

  // glace(mesh, bcs, rho, u, p): runs the Glace scheme, dispatched on the mesh dimension
  this->_addBuiltinFunction("glace",
                            std::make_shared<BuiltinFunctionEmbedder<
                              void(std::shared_ptr<const IMesh>, const BoundaryConditionDescriptorList&,
                                   const FunctionSymbolId&, const FunctionSymbolId&, const FunctionSymbolId&)>>(
                              [](std::shared_ptr<const IMesh> i_mesh, const BoundaryConditionDescriptorList& bc_list,
                                 const FunctionSymbolId& rho_id, const FunctionSymbolId& u_id,
                                 const FunctionSymbolId& p_id) -> void {
                                const auto dimension = i_mesh->dimension();
                                if (dimension == 1) {
                                  GlaceScheme<1>{i_mesh, bc_list, rho_id, u_id, p_id};
                                } else if (dimension == 2) {
                                  GlaceScheme<2>{i_mesh, bc_list, rho_id, u_id, p_id};
                                } else if (dimension == 3) {
                                  GlaceScheme<3>{i_mesh, bc_list, rho_id, u_id, p_id};
                                } else {
                                  throw UnexpectedError("invalid mesh dimension");
                                }
                              }));
}
