#include <language/modules/MeshModule.hpp>

#include <algebra/TinyVector.hpp>
#include <language/node_processor/ExecutionPolicy.hpp>
#include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/FunctionTable.hpp>
#include <language/utils/PugsFunctionAdapter.hpp>
#include <language/utils/SymbolTable.hpp>
#include <language/utils/TypeDescriptor.hpp>
#include <mesh/CartesianMeshBuilder.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/GmshReader.hpp>
#include <mesh/Mesh.hpp>
#include <utils/Exceptions.hpp>

#include <Kokkos_Core.hpp>

/// Primary template, deliberately left undefined: only the function-type
/// specialization below can be instantiated.
template <typename T>
class MeshTransformation;

/// Moves every node of a mesh through a user-provided language function.
///
/// Instantiated with a function type such as `TinyVector<2>(TinyVector<2>)`.
/// `OutputType::Dimension` selects the mesh dimension. The function is evaluated
/// through the PugsFunctionAdapter helpers (expression lookup, argument and
/// result conversion, execution contexts).
template <typename OutputType, typename... InputType>
class MeshTransformation<OutputType(InputType...)> : public PugsFunctionAdapter<OutputType(InputType...)>
{
  static constexpr size_t Dimension = OutputType::Dimension;
  using Adapter                     = PugsFunctionAdapter<OutputType(InputType...)>;

 public:
  /// Creates a new mesh that shares the connectivity of `p_mesh`. Its node
  /// coordinates are the images of the original ones through the function
  /// designated by `function_symbol_id`.
  ///
  /// @param function_symbol_id  function applied to each node coordinate
  /// @param p_mesh              source mesh; must actually be a
  ///                            Mesh<Connectivity<Dimension>> (the reference
  ///                            dynamic_cast throws std::bad_cast otherwise)
  /// @return a newly allocated mesh holding the transformed coordinates
  static inline std::shared_ptr<Mesh<Connectivity<Dimension>>>
  transform(const FunctionSymbolId& function_symbol_id, std::shared_ptr<const IMesh> p_mesh)
  {
    using MeshType        = Mesh<Connectivity<Dimension>>;
    using execution_space = typename Kokkos::DefaultExecutionSpace::execution_space;
    using TokenPool = Kokkos::Experimental::UniqueToken<execution_space, Kokkos::Experimental::UniqueTokenScope::Global>;

    const MeshType& source_mesh = dynamic_cast<const MeshType&>(*p_mesh);

    auto& function_expression = Adapter::getFunctionExpression(function_symbol_id);
    auto to_output            = Adapter::getResultConverter(function_expression.m_data_type);

    // Contexts are indexed by the unique token held by the running thread, so
    // evaluations that run concurrently never share a context.
    Array<ExecutionPolicy> contexts = Adapter::getContextList(function_expression);

    NodeValue<const OutputType> source_xr = source_mesh.xr();
    NodeValue<OutputType> image_xr(source_mesh.connectivity());

    TokenPool token_pool;

    parallel_for(source_mesh.numberOfNodes(), [=, &function_expression, &token_pool](NodeId node_id) {
      const int32_t slot = token_pool.acquire();

      auto& policy = contexts[slot];

      Adapter::convertArgs(policy.currentContext(), source_xr[node_id]);
      auto value        = function_expression.execute(policy);
      image_xr[node_id] = to_output(std::move(value));

      token_pool.release(slot);
    });

    return std::make_shared<MeshType>(source_mesh.shared_connectivity(), image_xr);
  }
};

MeshModule::MeshModule()
{
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const IMesh>>);

  this->_addBuiltinFunction("readGmsh",
                            std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<const IMesh>(const std::string&)>>(

                              [](const std::string& file_name) -> std::shared_ptr<const IMesh> {
                                GmshReader gmsh_reader(file_name);
                                return gmsh_reader.mesh();
                              }

                              ));

  this->_addBuiltinFunction("transform",
                            std::make_shared<BuiltinFunctionEmbedder<
                              std::shared_ptr<const IMesh>(std::shared_ptr<const IMesh>, const FunctionSymbolId&)>>(

                              [](std::shared_ptr<const IMesh> p_mesh,
                                 const FunctionSymbolId& function_id) -> std::shared_ptr<const IMesh> {
                                switch (p_mesh->dimension()) {
                                case 1: {
                                  using TransformT = TinyVector<1>(TinyVector<1>);
                                  return MeshTransformation<TransformT>::transform(function_id, p_mesh);
                                }
                                case 2: {
                                  using TransformT = TinyVector<2>(TinyVector<2>);
                                  return MeshTransformation<TransformT>::transform(function_id, p_mesh);
                                }
                                case 3: {
                                  using TransformT = TinyVector<3>(TinyVector<3>);
                                  return MeshTransformation<TransformT>::transform(function_id, p_mesh);
                                }
                                default: {
                                  throw UnexpectedError("invalid mesh dimension");
                                }
                                }
                              }

                              ));

  this->_addBuiltinFunction("cartesian1dMesh",
                            std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<
                              const IMesh>(const TinyVector<1>, const TinyVector<1>, const std::vector<uint64_t>&)>>(

                              [](const TinyVector<1> a, const TinyVector<1> b,
                                 const std::vector<uint64_t>& box_sizes) -> std::shared_ptr<const IMesh> {
                                constexpr uint64_t dimension = 1;

                                if (box_sizes.size() != dimension) {
                                  throw NormalError("expecting " + std::to_string(dimension) +
                                                    " dimensions, provided " + std::to_string(box_sizes.size()));
                                }

                                const TinyVector<dimension, uint64_t> sizes = [&]() {
                                  TinyVector<dimension, uint64_t> s;
                                  for (size_t i = 0; i < dimension; ++i) {
                                    s[i] = box_sizes[i];
                                  }
                                  return s;
                                }();

                                CartesianMeshBuilder builder{a, b, sizes};
                                return builder.mesh();
                              }

                              ));

  this->_addBuiltinFunction("cartesian2dMesh",
                            std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<
                              const IMesh>(const TinyVector<2>, const TinyVector<2>, const std::vector<uint64_t>&)>>(

                              [](const TinyVector<2> a, const TinyVector<2> b,
                                 const std::vector<uint64_t>& box_sizes) -> std::shared_ptr<const IMesh> {
                                constexpr uint64_t dimension = 2;

                                if (box_sizes.size() != dimension) {
                                  throw NormalError("expecting " + std::to_string(dimension) +
                                                    " dimensions, provided " + std::to_string(box_sizes.size()));
                                }

                                const TinyVector<dimension, uint64_t> sizes = [&]() {
                                  TinyVector<dimension, uint64_t> s;
                                  for (size_t i = 0; i < dimension; ++i) {
                                    s[i] = box_sizes[i];
                                  }
                                  return s;
                                }();

                                CartesianMeshBuilder builder{a, b, sizes};
                                return builder.mesh();
                              }

                              ));

  this->_addBuiltinFunction("cartesian3dMesh",
                            std::make_shared<BuiltinFunctionEmbedder<std::shared_ptr<
                              const IMesh>(const TinyVector<3>&, const TinyVector<3>&, const std::vector<uint64_t>&)>>(

                              [](const TinyVector<3>& a, const TinyVector<3>& b,
                                 const std::vector<uint64_t>& box_sizes) -> std::shared_ptr<const IMesh> {
                                constexpr uint64_t dimension = 3;

                                if (box_sizes.size() != dimension) {
                                  throw NormalError("expecting " + std::to_string(dimension) +
                                                    " dimensions, provided " + std::to_string(box_sizes.size()));
                                }

                                const TinyVector<dimension, uint64_t> sizes = [&]() {
                                  TinyVector<dimension, uint64_t> s;
                                  for (size_t i = 0; i < dimension; ++i) {
                                    s[i] = box_sizes[i];
                                  }
                                  return s;
                                }();

                                CartesianMeshBuilder builder{a, b, sizes};
                                return builder.mesh();
                              }

                              ));
}
