#include <language/modules/MeshModule.hpp>

#include <algebra/TinyVector.hpp>
#include <language/node_processor/ExecutionPolicy.hpp>
#include <language/utils/BinaryOperatorProcessorBuilder.hpp>
#include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/FunctionTable.hpp>
#include <language/utils/OStream.hpp>
#include <language/utils/OperatorRepository.hpp>
#include <language/utils/SymbolTable.hpp>
#include <language/utils/TypeDescriptor.hpp>
#include <mesh/CartesianMeshBuilder.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/DualMeshManager.hpp>
#include <mesh/GmshReader.hpp>
#include <mesh/IBoundaryDescriptor.hpp>
#include <mesh/IZoneDescriptor.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshRelaxer.hpp>
#include <mesh/MeshTransformer.hpp>
#include <mesh/NamedBoundaryDescriptor.hpp>
#include <mesh/NamedZoneDescriptor.hpp>
#include <mesh/NumberedBoundaryDescriptor.hpp>
#include <mesh/NumberedZoneDescriptor.hpp>
#include <utils/Exceptions.hpp>

#include <Kokkos_Core.hpp>

MeshModule::MeshModule()
{
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const IMesh>>);
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const IBoundaryDescriptor>>);
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const IZoneDescriptor>>);

  this->_addBuiltinFunction("readGmsh",
                            std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<const IMesh>(const std::string&)>>(

                              [](const std::string& file_name) -> std::shared_ptr<const IMesh> {
                                GmshReader gmsh_reader(file_name);
                                return gmsh_reader.mesh();
                              }

                              ));

  this->_addBuiltinFunction("boundaryName",
                            std::make_shared<
                              BuiltinFunctionEmbedder<std::shared_ptr<const IBoundaryDescriptor>(const std::string&)>>(

                              [](const std::string& boundary_name) -> std::shared_ptr<const IBoundaryDescriptor> {
                                return std::make_shared<NamedBoundaryDescriptor>(boundary_name);
                              }

                              ));

  this->_addBuiltinFunction("zoneTag",
                            std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<const IZoneDescriptor>(int64_t)>>(

                              [](int64_t zone_tag) -> std::shared_ptr<const IZoneDescriptor> {
                                return std::make_shared<NumberedZoneDescriptor>(zone_tag);
                              }

                              ));

  this->_addBuiltinFunction("zoneName",
                            std::make_shared<
                              BuiltinFunctionEmbedder<std::shared_ptr<const IZoneDescriptor>(const std::string&)>>(

                              [](const std::string& zone_name) -> std::shared_ptr<const IZoneDescriptor> {
                                return std::make_shared<NamedZoneDescriptor>(zone_name);
                              }

                              ));

  this->_addBuiltinFunction("boundaryTag",
                            std::make_shared<
                              BuiltinFunctionEmbedder<std::shared_ptr<const IBoundaryDescriptor>(int64_t)>>(

                              [](int64_t boundary_tag) -> std::shared_ptr<const IBoundaryDescriptor> {
                                return std::make_shared<NumberedBoundaryDescriptor>(boundary_tag);
                              }

                              ));

  this->_addBuiltinFunction("transform",
                            std::make_shared<BuiltinFunctionEmbedder<
                              std::shared_ptr<const IMesh>(std::shared_ptr<const IMesh>, const FunctionSymbolId&)>>(

                              [](std::shared_ptr<const IMesh> p_mesh,
                                 const FunctionSymbolId& function_id) -> std::shared_ptr<const IMesh> {
                                return MeshTransformer{}.transform(function_id, p_mesh);
                              }

                              ));

  this->_addBuiltinFunction("relax",
                            std::make_shared<BuiltinFunctionEmbedder<
                              std::shared_ptr<const IMesh>(const std::shared_ptr<const IMesh>&,
                                                           const std::shared_ptr<const IMesh>&, const double&)>>(

                              [](const std::shared_ptr<const IMesh>& source_mesh,
                                 const std::shared_ptr<const IMesh>& destination_mesh,
                                 const double& theta) -> std::shared_ptr<const IMesh> {
                                return MeshRelaxer{}.relax(source_mesh, destination_mesh, theta);
                              }

                              ));

  this->_addBuiltinFunction("cartesianMesh",
                            std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<
                              const IMesh>(const TinyVector<1>, const TinyVector<1>, const std::vector<uint64_t>&)>>(

                              [](const TinyVector<1> a, const TinyVector<1> b,
                                 const std::vector<uint64_t>& box_sizes) -> std::shared_ptr<const IMesh> {
                                constexpr uint64_t dimension = 1;

                                if (box_sizes.size() != dimension) {
                                  throw NormalError("expecting " + stringify(dimension) + " dimensions, provided " +
                                                    stringify(box_sizes.size()));
                                }

                                const TinyVector<dimension, uint64_t> sizes = [&]() {
                                  TinyVector<dimension, uint64_t> s;
                                  for (size_t i = 0; i < dimension; ++i) {
                                    s[i] = box_sizes[i];
                                  }
                                  return s;
                                }();

                                CartesianMeshBuilder builder{a, b, sizes};
                                return builder.mesh();
                              }

                              ));

  this->_addBuiltinFunction("cartesianMesh",
                            std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<
                              const IMesh>(const TinyVector<2>, const TinyVector<2>, const std::vector<uint64_t>&)>>(

                              [](const TinyVector<2> a, const TinyVector<2> b,
                                 const std::vector<uint64_t>& box_sizes) -> std::shared_ptr<const IMesh> {
                                constexpr uint64_t dimension = 2;

                                if (box_sizes.size() != dimension) {
                                  throw NormalError("expecting " + stringify(dimension) + " dimensions, provided " +
                                                    stringify(box_sizes.size()));
                                }

                                const TinyVector<dimension, uint64_t> sizes = [&]() {
                                  TinyVector<dimension, uint64_t> s;
                                  for (size_t i = 0; i < dimension; ++i) {
                                    s[i] = box_sizes[i];
                                  }
                                  return s;
                                }();

                                CartesianMeshBuilder builder{a, b, sizes};
                                return builder.mesh();
                              }

                              ));

  this->_addBuiltinFunction("cartesianMesh",
                            std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<
                              const IMesh>(const TinyVector<3>&, const TinyVector<3>&, const std::vector<uint64_t>&)>>(

                              [](const TinyVector<3>& a, const TinyVector<3>& b,
                                 const std::vector<uint64_t>& box_sizes) -> std::shared_ptr<const IMesh> {
                                constexpr uint64_t dimension = 3;

                                if (box_sizes.size() != dimension) {
                                  throw NormalError("expecting " + stringify(dimension) + " dimensions, provided " +
                                                    stringify(box_sizes.size()));
                                }

                                const TinyVector<dimension, uint64_t> sizes = [&]() {
                                  TinyVector<dimension, uint64_t> s;
                                  for (size_t i = 0; i < dimension; ++i) {
                                    s[i] = box_sizes[i];
                                  }
                                  return s;
                                }();

                                CartesianMeshBuilder builder{a, b, sizes};
                                return builder.mesh();
                              }

                              ));

  this->_addBuiltinFunction("diamondDual",
                            std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<const IMesh>(
                              const std::shared_ptr<const IMesh>&)>>(

                              [](const std::shared_ptr<const IMesh>& i_mesh) -> std::shared_ptr<const IMesh> {
                                switch (i_mesh->dimension()) {
                                case 1: {
                                  using MeshType = Mesh<Connectivity<1>>;

                                  std::shared_ptr p_mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);
                                  return DualMeshManager::instance().getDiamondDualMesh(*p_mesh);
                                }
                                case 2: {
                                  using MeshType = Mesh<Connectivity<2>>;

                                  std::shared_ptr p_mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);
                                  return DualMeshManager::instance().getDiamondDualMesh(*p_mesh);
                                }
                                case 3: {
                                  using MeshType = Mesh<Connectivity<3>>;

                                  std::shared_ptr p_mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);
                                  return DualMeshManager::instance().getDiamondDualMesh(*p_mesh);
                                }
                                default: {
                                  throw UnexpectedError("invalid dimension");
                                }
                                }
                              }

                              ));

  this->_addBuiltinFunction("medianDual",
                            std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<const IMesh>(
                              const std::shared_ptr<const IMesh>&)>>(

                              [](const std::shared_ptr<const IMesh>& i_mesh) -> std::shared_ptr<const IMesh> {
                                switch (i_mesh->dimension()) {
                                case 1: {
                                  using MeshType = Mesh<Connectivity<1>>;

                                  std::shared_ptr p_mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);
                                  return DualMeshManager::instance().getMedianDualMesh(*p_mesh);
                                }
                                case 2: {
                                  using MeshType = Mesh<Connectivity<2>>;

                                  std::shared_ptr p_mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);
                                  return DualMeshManager::instance().getMedianDualMesh(*p_mesh);
                                }
                                case 3: {
                                  using MeshType = Mesh<Connectivity<3>>;

                                  std::shared_ptr p_mesh = std::dynamic_pointer_cast<const MeshType>(i_mesh);
                                  return DualMeshManager::instance().getMedianDualMesh(*p_mesh);
                                }
                                default: {
                                  throw UnexpectedError("invalid dimension");
                                }
                                }
                              }

                              ));
}

void
MeshModule::registerOperators() const
{
  OperatorRepository& repository = OperatorRepository::instance();

  repository.addBinaryOperator<language::shift_left_op>(
    std::make_shared<BinaryOperatorProcessorBuilder<language::shift_left_op, std::shared_ptr<const OStream>,
                                                    std::shared_ptr<const OStream>, std::shared_ptr<const IMesh>>>());
}