#include <language/modules/MeshModule.hpp>

#include <language/modules/ModuleRepository.hpp>

#include <algebra/TinyVector.hpp>
#include <language/node_processor/ExecutionPolicy.hpp>
#include <language/utils/BinaryOperatorProcessorBuilder.hpp>
#include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/CheckpointResumeRepository.hpp>
#include <language/utils/FunctionTable.hpp>
#include <language/utils/ItemArrayVariantFunctionInterpoler.hpp>
#include <language/utils/ItemValueVariantFunctionInterpoler.hpp>
#include <language/utils/OStream.hpp>
#include <language/utils/OperatorRepository.hpp>
#include <language/utils/SymbolTable.hpp>
#include <language/utils/TypeDescriptor.hpp>
#include <mesh/CartesianMeshBuilder.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/ConnectivityUtils.hpp>
#include <mesh/DualMeshManager.hpp>
#include <mesh/GmshReader.hpp>
#include <mesh/IBoundaryDescriptor.hpp>
#include <mesh/IZoneDescriptor.hpp>
#include <mesh/ItemArrayVariant.hpp>
#include <mesh/ItemValueVariant.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshRelaxer.hpp>
#include <mesh/MeshTransformer.hpp>
#include <mesh/MeshUtils.hpp>
#include <mesh/MeshVariant.hpp>
#include <mesh/NamedBoundaryDescriptor.hpp>
#include <mesh/NamedInterfaceDescriptor.hpp>
#include <mesh/NamedZoneDescriptor.hpp>
#include <mesh/NumberedBoundaryDescriptor.hpp>
#include <mesh/NumberedInterfaceDescriptor.hpp>
#include <mesh/NumberedZoneDescriptor.hpp>
#include <mesh/SubItemArrayPerItemVariant.hpp>
#include <mesh/SubItemValuePerItemVariant.hpp>
#include <utils/Exceptions.hpp>

#include <utils/checkpointing/ReadIBoundaryDescriptor.hpp>
#include <utils/checkpointing/ReadIInterfaceDescriptor.hpp>
#include <utils/checkpointing/ReadIZoneDescriptor.hpp>
#include <utils/checkpointing/ReadItemArrayVariant.hpp>
#include <utils/checkpointing/ReadItemType.hpp>
#include <utils/checkpointing/ReadItemValueVariant.hpp>
#include <utils/checkpointing/ReadMesh.hpp>
#include <utils/checkpointing/ReadSubItemArrayPerItemVariant.hpp>
#include <utils/checkpointing/ReadSubItemValuePerItemVariant.hpp>
#include <utils/checkpointing/WriteIBoundaryDescriptor.hpp>
#include <utils/checkpointing/WriteIInterfaceDescriptor.hpp>
#include <utils/checkpointing/WriteIZoneDescriptor.hpp>
#include <utils/checkpointing/WriteItemArrayVariant.hpp>
#include <utils/checkpointing/WriteItemType.hpp>
#include <utils/checkpointing/WriteItemValueVariant.hpp>
#include <utils/checkpointing/WriteMesh.hpp>
#include <utils/checkpointing/WriteSubItemArrayPerItemVariant.hpp>
#include <utils/checkpointing/WriteSubItemValuePerItemVariant.hpp>

MeshModule::MeshModule()
{
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const MeshVariant>>);
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const IBoundaryDescriptor>>);
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const IInterfaceDescriptor>>);
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const IZoneDescriptor>>);

  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const ItemType>>);

  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const ItemValueVariant>>);
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const ItemArrayVariant>>);
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const SubItemValuePerItemVariant>>);
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const SubItemArrayPerItemVariant>>);

  this->_addBuiltinFunction("cell", std::function(

                                      []() -> std::shared_ptr<const ItemType> {
                                        return std::make_shared<ItemType>(ItemType::cell);
                                      }

                                      ));

  this->_addBuiltinFunction("face", std::function(

                                      []() -> std::shared_ptr<const ItemType> {
                                        return std::make_shared<ItemType>(ItemType::face);
                                      }

                                      ));

  this->_addBuiltinFunction("edge", std::function(

                                      []() -> std::shared_ptr<const ItemType> {
                                        return std::make_shared<ItemType>(ItemType::edge);
                                      }

                                      ));

  this->_addBuiltinFunction("node", std::function(

                                      []() -> std::shared_ptr<const ItemType> {
                                        return std::make_shared<ItemType>(ItemType::node);
                                      }

                                      ));

  this->_addBuiltinFunction("readGmsh", std::function(

                                          [](const std::string& file_name) -> std::shared_ptr<const MeshVariant> {
                                            GmshReader gmsh_reader(file_name);
                                            return gmsh_reader.mesh();
                                          }

                                          ));

  this->_addBuiltinFunction("boundaryName",
                            std::function(

                              [](const std::string& boundary_name) -> std::shared_ptr<const IBoundaryDescriptor> {
                                return std::make_shared<NamedBoundaryDescriptor>(boundary_name);
                              }

                              ));

  this->_addBuiltinFunction("boundaryTag", std::function(

                                             [](int64_t boundary_tag) -> std::shared_ptr<const IBoundaryDescriptor> {
                                               return std::make_shared<NumberedBoundaryDescriptor>(boundary_tag);
                                             }

                                             ));

  this->_addBuiltinFunction("interfaceName",
                            std::function(

                              [](const std::string& interface_name) -> std::shared_ptr<const IInterfaceDescriptor> {
                                return std::make_shared<NamedInterfaceDescriptor>(interface_name);
                              }

                              ));

  this->_addBuiltinFunction("interfaceTag", std::function(

                                              [](int64_t interface_tag) -> std::shared_ptr<const IInterfaceDescriptor> {
                                                return std::make_shared<NumberedInterfaceDescriptor>(interface_tag);
                                              }

                                              ));

  this->_addBuiltinFunction("zoneTag", std::function(

                                         [](int64_t zone_tag) -> std::shared_ptr<const IZoneDescriptor> {
                                           return std::make_shared<NumberedZoneDescriptor>(zone_tag);
                                         }

                                         ));

  this->_addBuiltinFunction("zoneName", std::function(

                                          [](const std::string& zone_name) -> std::shared_ptr<const IZoneDescriptor> {
                                            return std::make_shared<NamedZoneDescriptor>(zone_name);
                                          }

                                          ));

  this->_addBuiltinFunction("interpolate",
                            std::function(

                              [](std::shared_ptr<const MeshVariant> mesh_v, std::shared_ptr<const ItemType> item_type,
                                 const FunctionSymbolId& function_id) -> std::shared_ptr<const ItemValueVariant> {
                                return ItemValueVariantFunctionInterpoler{mesh_v, *item_type, function_id}
                                  .interpolate();
                              }

                              ));

  this->_addBuiltinFunction("interpolate_array",
                            std::function(

                              [](std::shared_ptr<const MeshVariant> mesh_v, std::shared_ptr<const ItemType> item_type,
                                 const std::vector<FunctionSymbolId>& function_id_list)
                                -> std::shared_ptr<const ItemArrayVariant> {
                                return ItemArrayVariantFunctionInterpoler{mesh_v, *item_type, function_id_list}
                                  .interpolate();
                              }

                              ));

  this->_addBuiltinFunction("transform",
                            std::function(

                              [](std::shared_ptr<const MeshVariant> mesh_v,
                                 const FunctionSymbolId& function_id) -> std::shared_ptr<const MeshVariant> {
                                return MeshTransformer{}.transform(function_id, mesh_v);
                              }

                              ));

  this->_addBuiltinFunction("relax", std::function(

                                       [](const std::shared_ptr<const MeshVariant>& source_mesh_v,
                                          const std::shared_ptr<const MeshVariant>& destination_mesh_v,
                                          const double& theta) -> std::shared_ptr<const MeshVariant> {
                                         return MeshRelaxer{}.relax(source_mesh_v, destination_mesh_v, theta);
                                       }

                                       ));

  this->_addBuiltinFunction("check_connectivity_ordering",
                            std::function(

                              [](const std::shared_ptr<const MeshVariant>& mesh_v) -> bool {
                                return checkConnectivityOrdering(mesh_v);
                              }

                              ));

  this->_addBuiltinFunction("cartesianMesh",
                            std::function(

                              [](const TinyVector<1> a, const TinyVector<1> b,
                                 const std::vector<uint64_t>& box_sizes) -> std::shared_ptr<const MeshVariant> {
                                constexpr uint64_t dimension = 1;

                                if (box_sizes.size() != dimension) {
                                  throw NormalError("expecting " + stringify(dimension) + " dimensions, provided " +
                                                    stringify(box_sizes.size()));
                                }

                                const TinyVector<dimension, uint64_t> sizes = [&]() {
                                  TinyVector<dimension, uint64_t> s;
                                  for (size_t i = 0; i < dimension; ++i) {
                                    s[i] = box_sizes[i];
                                  }
                                  return s;
                                }();

                                CartesianMeshBuilder builder{a, b, sizes};
                                return builder.mesh();
                              }

                              ));

  this->_addBuiltinFunction("cartesianMesh",
                            std::function(

                              [](const TinyVector<2> a, const TinyVector<2> b,
                                 const std::vector<uint64_t>& box_sizes) -> std::shared_ptr<const MeshVariant> {
                                constexpr uint64_t dimension = 2;

                                if (box_sizes.size() != dimension) {
                                  throw NormalError("expecting " + stringify(dimension) + " dimensions, provided " +
                                                    stringify(box_sizes.size()));
                                }

                                const TinyVector<dimension, uint64_t> sizes = [&]() {
                                  TinyVector<dimension, uint64_t> s;
                                  for (size_t i = 0; i < dimension; ++i) {
                                    s[i] = box_sizes[i];
                                  }
                                  return s;
                                }();

                                CartesianMeshBuilder builder{a, b, sizes};
                                return builder.mesh();
                              }

                              ));

  this->_addBuiltinFunction("cartesianMesh",
                            std::function(

                              [](const TinyVector<3>& a, const TinyVector<3>& b,
                                 const std::vector<uint64_t>& box_sizes) -> std::shared_ptr<const MeshVariant> {
                                constexpr uint64_t dimension = 3;

                                if (box_sizes.size() != dimension) {
                                  throw NormalError("expecting " + stringify(dimension) + " dimensions, provided " +
                                                    stringify(box_sizes.size()));
                                }

                                const TinyVector<dimension, uint64_t> sizes = [&]() {
                                  TinyVector<dimension, uint64_t> s;
                                  for (size_t i = 0; i < dimension; ++i) {
                                    s[i] = box_sizes[i];
                                  }
                                  return s;
                                }();

                                CartesianMeshBuilder builder{a, b, sizes};
                                return builder.mesh();
                              }

                              ));

  this->_addBuiltinFunction("diamondDual", std::function(

                                             [](const std::shared_ptr<const MeshVariant>& mesh_v)
                                               -> std::shared_ptr<const MeshVariant> {
                                               return DualMeshManager::instance().getDiamondDualMesh(mesh_v);
                                             }

                                             ));

  this->_addBuiltinFunction("medianDual", std::function(

                                            [](const std::shared_ptr<const MeshVariant>& mesh_v)
                                              -> std::shared_ptr<const MeshVariant> {
                                              return DualMeshManager::instance().getMedianDualMesh(mesh_v);
                                            }

                                            ));

  this->_addBuiltinFunction("cell_owner", std::function(

                                            [](const std::shared_ptr<const MeshVariant>& mesh_v)
                                              -> std::shared_ptr<const ItemValueVariant> {
                                              return std::visit(
                                                [&](auto&& mesh) {
                                                  const auto& connectivity = mesh->connectivity();
                                                  auto cell_owner          = connectivity.cellOwner();
                                                  CellValue<long int> cell_owner_long{connectivity};
                                                  parallel_for(
                                                    connectivity.numberOfCells(), PUGS_LAMBDA(const CellId cell_id) {
                                                      cell_owner_long[cell_id] = cell_owner[cell_id];
                                                    });
                                                  return std::make_shared<const ItemValueVariant>(cell_owner_long);
                                                },
                                                mesh_v->variant());
                                            }

                                            ));

  this->_addBuiltinFunction("node_owner", std::function(

                                            [](const std::shared_ptr<const MeshVariant>& mesh_v)
                                              -> std::shared_ptr<const ItemValueVariant> {
                                              return std::visit(
                                                [&](auto&& mesh) {
                                                  const auto& connectivity = mesh->connectivity();
                                                  auto node_owner          = connectivity.nodeOwner();
                                                  NodeValue<long int> node_owner_long{connectivity};
                                                  parallel_for(
                                                    connectivity.numberOfNodes(), PUGS_LAMBDA(const NodeId node_id) {
                                                      node_owner_long[node_id] = node_owner[node_id];
                                                    });
                                                  return std::make_shared<const ItemValueVariant>(node_owner_long);
                                                },
                                                mesh_v->variant());
                                            }

                                            ));
}

void
MeshModule::registerOperators() const
{
  OperatorRepository& repository = OperatorRepository::instance();

  repository.addBinaryOperator<language::shift_left_op>(
    std::make_shared<
      BinaryOperatorProcessorBuilder<language::shift_left_op, std::shared_ptr<const OStream>,
                                     std::shared_ptr<const OStream>, std::shared_ptr<const MeshVariant>>>());
}

void
MeshModule::registerCheckpointResume() const
{
#ifdef PUGS_HAS_HDF5
  // Each mesh-related language type gets one writer (checkpoint) and one reader
  // (resume). Every pair forwards to the matching checkpointing::write*/read*
  // routine. Without HDF5 support, nothing is registered.
  CheckpointResumeRepository& checkpoint_resume = CheckpointResumeRepository::instance();

  // Meshes
  checkpoint_resume.addCheckpointResume(
    ast_node_data_type_from<std::shared_ptr<const MeshVariant>>,
    std::function([](const std::string& name, const EmbeddedData& data, HighFive::File& file,
                     HighFive::Group& checkpoint, HighFive::Group& symbols) {
      checkpointing::writeMesh(name, data, file, checkpoint, symbols);
    }),
    std::function([](const std::string& name, const HighFive::Group& symbols) -> EmbeddedData {
      return checkpointing::readMesh(name, symbols);
    }));

  // Boundary descriptors
  checkpoint_resume.addCheckpointResume(
    ast_node_data_type_from<std::shared_ptr<const IBoundaryDescriptor>>,
    std::function([](const std::string& name, const EmbeddedData& data, HighFive::File& file,
                     HighFive::Group& checkpoint, HighFive::Group& symbols) {
      checkpointing::writeIBoundaryDescriptor(name, data, file, checkpoint, symbols);
    }),
    std::function([](const std::string& name, const HighFive::Group& symbols) -> EmbeddedData {
      return checkpointing::readIBoundaryDescriptor(name, symbols);
    }));

  // Interface descriptors
  checkpoint_resume.addCheckpointResume(
    ast_node_data_type_from<std::shared_ptr<const IInterfaceDescriptor>>,
    std::function([](const std::string& name, const EmbeddedData& data, HighFive::File& file,
                     HighFive::Group& checkpoint, HighFive::Group& symbols) {
      checkpointing::writeIInterfaceDescriptor(name, data, file, checkpoint, symbols);
    }),
    std::function([](const std::string& name, const HighFive::Group& symbols) -> EmbeddedData {
      return checkpointing::readIInterfaceDescriptor(name, symbols);
    }));

  // Zone descriptors
  checkpoint_resume.addCheckpointResume(
    ast_node_data_type_from<std::shared_ptr<const IZoneDescriptor>>,
    std::function([](const std::string& name, const EmbeddedData& data, HighFive::File& file,
                     HighFive::Group& checkpoint, HighFive::Group& symbols) {
      checkpointing::writeIZoneDescriptor(name, data, file, checkpoint, symbols);
    }),
    std::function([](const std::string& name, const HighFive::Group& symbols) -> EmbeddedData {
      return checkpointing::readIZoneDescriptor(name, symbols);
    }));

  // Item types
  checkpoint_resume.addCheckpointResume(
    ast_node_data_type_from<std::shared_ptr<const ItemType>>,
    std::function([](const std::string& name, const EmbeddedData& data, HighFive::File& file,
                     HighFive::Group& checkpoint, HighFive::Group& symbols) {
      checkpointing::writeItemType(name, data, file, checkpoint, symbols);
    }),
    std::function([](const std::string& name, const HighFive::Group& symbols) -> EmbeddedData {
      return checkpointing::readItemType(name, symbols);
    }));

  // Item arrays
  checkpoint_resume.addCheckpointResume(
    ast_node_data_type_from<std::shared_ptr<const ItemArrayVariant>>,
    std::function([](const std::string& name, const EmbeddedData& data, HighFive::File& file,
                     HighFive::Group& checkpoint, HighFive::Group& symbols) {
      checkpointing::writeItemArrayVariant(name, data, file, checkpoint, symbols);
    }),
    std::function([](const std::string& name, const HighFive::Group& symbols) -> EmbeddedData {
      return checkpointing::readItemArrayVariant(name, symbols);
    }));

  // Item values
  checkpoint_resume.addCheckpointResume(
    ast_node_data_type_from<std::shared_ptr<const ItemValueVariant>>,
    std::function([](const std::string& name, const EmbeddedData& data, HighFive::File& file,
                     HighFive::Group& checkpoint, HighFive::Group& symbols) {
      checkpointing::writeItemValueVariant(name, data, file, checkpoint, symbols);
    }),
    std::function([](const std::string& name, const HighFive::Group& symbols) -> EmbeddedData {
      return checkpointing::readItemValueVariant(name, symbols);
    }));

  // Sub-item arrays per item
  checkpoint_resume.addCheckpointResume(
    ast_node_data_type_from<std::shared_ptr<const SubItemArrayPerItemVariant>>,
    std::function([](const std::string& name, const EmbeddedData& data, HighFive::File& file,
                     HighFive::Group& checkpoint, HighFive::Group& symbols) {
      checkpointing::writeSubItemArrayPerItemVariant(name, data, file, checkpoint, symbols);
    }),
    std::function([](const std::string& name, const HighFive::Group& symbols) -> EmbeddedData {
      return checkpointing::readSubItemArrayPerItemVariant(name, symbols);
    }));

  // Sub-item values per item
  checkpoint_resume.addCheckpointResume(
    ast_node_data_type_from<std::shared_ptr<const SubItemValuePerItemVariant>>,
    std::function([](const std::string& name, const EmbeddedData& data, HighFive::File& file,
                     HighFive::Group& checkpoint, HighFive::Group& symbols) {
      checkpointing::writeSubItemValuePerItemVariant(name, data, file, checkpoint, symbols);
    }),
    std::function([](const std::string& name, const HighFive::Group& symbols) -> EmbeddedData {
      return checkpointing::readSubItemValuePerItemVariant(name, symbols);
    }));

#endif   // PUGS_HAS_HDF5
}

// Namespace-scope subscription object. It is constructed during static
// initialization and presumably registers MeshModule with the ModuleRepository
// (NOTE(review): behavior inferred from the Subscribe name — confirm in ModuleRepository.hpp).
ModuleRepository::Subscribe<MeshModule> mesh_module{};