#include <language/modules/MeshModule.hpp>

#include <language/BuiltinFunctionEmbedder.hpp>
#include <language/FunctionTable.hpp>
#include <language/PugsFunctionAdapter.hpp>
#include <language/SymbolTable.hpp>
#include <language/TypeDescriptor.hpp>
#include <language/node_processor/ExecutionPolicy.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/GmshReader.hpp>
#include <mesh/Mesh.hpp>
#include <utils/Exceptions.hpp>

#include <Kokkos_Core.hpp>

#include <array>
#include <cstdio>

// Language-level data type associated with std::shared_ptr<IMesh>: a
// `type_id_t` whose type name is "mesh". The same name is used below when the
// module registers its type descriptor (via typeName()).
template <>
inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<IMesh>> = {ASTNodeDataType::type_id_t, "mesh"};

// Primary template: only the function-signature specialization below is defined.
template <typename T>
class MeshTransformation;

/**
 * Evaluates a user function of signature OutputType(InputType...) on every
 * node position of a mesh and builds a new mesh from the results.
 *
 * The space dimension is taken from OutputType::Dimension.
 */
template <typename OutputType, typename... InputType>
class MeshTransformation<OutputType(InputType...)> : public PugsFunctionAdapter<OutputType(InputType...)>
{
  static constexpr size_t Dimension = OutputType::Dimension;
  using Adapter                     = PugsFunctionAdapter<OutputType(InputType...)>;
  using MeshType                    = Mesh<Connectivity<Dimension>>;

 public:
  /**
   * Builds a mesh that reuses the connectivity of @a p_mesh and whose node
   * coordinates are the images of the original ones by the given function.
   *
   * @param function_symbol_id identifier of the function to evaluate
   * @param p_mesh mesh to transform; it is down-cast to a mesh of dimension
   *        `Dimension` (the reference dynamic_cast throws std::bad_cast on
   *        a type mismatch)
   * @return the newly created mesh
   */
  static inline std::shared_ptr<MeshType>
  transform(const FunctionSymbolId& function_symbol_id, std::shared_ptr<const IMesh> p_mesh)
  {
    const MeshType& source_mesh = dynamic_cast<const MeshType&>(*p_mesh);

    const auto flatten_args = Adapter::getFlattenArgs(function_symbol_id);

    auto& expression    = Adapter::getFunctionExpression(function_symbol_id);
    auto convert_result = Adapter::getResultConverter(expression.m_data_type);

    // NOTE(review): presumably one execution policy per concurrent worker,
    // indexed by the unique token acquired below -- confirm against
    // PugsFunctionAdapter::getContextList.
    Array<ExecutionPolicy> context_list = Adapter::getContextList(expression);

    NodeValue<const OutputType> source_xr = source_mesh.xr();
    NodeValue<OutputType> new_xr(source_mesh.connectivity());

    using ExecSpace = typename Kokkos::DefaultExecutionSpace::execution_space;
    Kokkos::Experimental::UniqueToken<ExecSpace, Kokkos::Experimental::UniqueTokenScope::Global> tokens;

    parallel_for(source_mesh.numberOfNodes(),
                 [context_list, source_xr, new_xr, convert_result, &expression, &flatten_args,
                  &tokens](NodeId node_id) {
                   // The token selects which execution policy this evaluation uses.
                   const int32_t token = tokens.acquire();

                   auto& policy = context_list[token];

                   // Load the node position as function argument, evaluate, store the image.
                   Adapter::convertArgs(policy.currentContext(), flatten_args, source_xr[node_id]);
                   auto value      = expression.execute(policy);
                   new_xr[node_id] = convert_result(std::move(value));

                   tokens.release(token);
                 });

    return std::make_shared<MeshType>(source_mesh.shared_connectivity(), new_xr);
  }
};

/**
 * Registers what the mesh module exposes to the language:
 *  - the type descriptor named after the IMesh data type,
 *  - the builtin `readGmsh` (file name -> mesh),
 *  - the builtin `transform` (mesh, function -> mesh).
 */
MeshModule::MeshModule()
{
  using MeshPtr = std::shared_ptr<IMesh>;

  this->_addTypeDescriptor(std::make_shared<TypeDescriptor>(ast_node_data_type_from<MeshPtr>.typeName()));

  // readGmsh: builds a GmshReader on the given file name and returns its mesh.
  std::function<MeshPtr(std::string)> read_gmsh = [](const std::string& file_name) -> MeshPtr {
    GmshReader reader(file_name);
    return reader.mesh();
  };

  this->_addBuiltinFunction("readGmsh",
                            std::make_shared<BuiltinFunctionEmbedder<MeshPtr, std::string>>(std::move(read_gmsh)));

  // transform: dispatches on the runtime mesh dimension to the matching
  // MeshTransformation instantiation (TinyVector<d> -> TinyVector<d>).
  std::function<MeshPtr(MeshPtr, FunctionSymbolId)> transform_mesh =
    [](MeshPtr p_mesh, const FunctionSymbolId& function_id) -> MeshPtr {
    switch (p_mesh->dimension()) {
    case 1:
      return MeshTransformation<TinyVector<1>(TinyVector<1>)>::transform(function_id, p_mesh);
    case 2:
      return MeshTransformation<TinyVector<2>(TinyVector<2>)>::transform(function_id, p_mesh);
    case 3:
      return MeshTransformation<TinyVector<3>(TinyVector<3>)>::transform(function_id, p_mesh);
    default:
      break;
    }
    // Only dimensions 1, 2 and 3 are handled.
    throw UnexpectedError("invalid mesh dimension");
  };

  this->_addBuiltinFunction("transform",
                            std::make_shared<BuiltinFunctionEmbedder<MeshPtr, MeshPtr, FunctionSymbolId>>(
                              std::move(transform_mesh)));
}