#ifndef PUGS_FUNCTION_ADAPTER_HPP
#define PUGS_FUNCTION_ADAPTER_HPP

#include <language/ASTNode.hpp>
#include <language/ASTNodeDataType.hpp>
#include <language/SymbolTable.hpp>
#include <language/node_processor/ExecutionPolicy.hpp>
#include <language/utils/ASTNodeDataTypeTraits.hpp>
#include <language/utils/RuntimeError.hpp>
#include <utils/Array.hpp>
#include <utils/Exceptions.hpp>
#include <utils/PugsMacros.hpp>

#include <Kokkos_Core.hpp>

#include <array>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <variant>

// Primary template: intentionally left undefined, only the function-type
// specialization below is usable.
template <typename T>
class PugsFunctionAdapter;

// Adapter helpers binding a user-defined pugs language function to a C++
// signature OutputType(InputType...). The protected members are the
// interface offered to derived adapters.
template <typename OutputType, typename... InputType>
class PugsFunctionAdapter<OutputType(InputType...)>
{
 protected:
  // One entry per C++ argument (same order as InputType...): non-zero when
  // that argument must be spread component-wise over a list of parameters.
  using FlattenList = std::array<int32_t, sizeof...(InputType)>;

 private:
  // Fallback for argument types that cannot be flattened. Reaching it means
  // the flatten list and the argument types disagree.
  template <typename T>
  PUGS_INLINE static void
  _flattenArgT(const T&, ExecutionPolicy::Context&, size_t&)
  {
    throw UnexpectedError("cannot flatten type " + demangle<T>());
  }

  // Writes the N components of t into consecutive context slots starting at
  // i_context, then leaves i_context just past them so that the following
  // argument does not overwrite these components.
  template <size_t N>
  PUGS_INLINE static void
  _flattenArgT(const TinyVector<N>& t, ExecutionPolicy::Context& context, size_t& i_context)
  {
    for (size_t i = 0; i < N; ++i) {
      context[i_context++] = t[i];
    }
  }

  // Recursively stores (t, args...) into context starting at slot i_context.
  // The pack is the trailing parameter so that it is deduced from the call
  // (a function parameter pack that is not last is a non-deduced context).
  template <typename T, typename... Args>
  PUGS_INLINE static void
  _convertArgs(ExecutionPolicy::Context& context,
               const FlattenList& flatten,
               size_t i_context,
               const T& t,
               const Args&... args)
  {
    // Position of t inside InputType...; same convention as the index used
    // by _checkValidArgumentDataType when filling the flatten list.
    constexpr size_t i_argument = sizeof...(InputType) - 1 - sizeof...(Args);

    if (flatten[i_argument]) {
      _flattenArgT(t, context, i_context);
    } else {
      context[i_context++] = t;
    }
    if constexpr (sizeof...(Args) > 0) {
      _convertArgs(context, flatten, i_context, args...);
    }
  }

  // Checks whether a C++ argument of type Arg can be passed to a function
  // whose parameter expression is input_expression, and records in
  // flatten_list whether it must be flattened (a vector argument given to a
  // list of parameters, each of which must accept a double).
  template <typename Arg, typename... RemainingArgs>
  [[nodiscard]] PUGS_INLINE static bool
  _checkValidArgumentDataType(const ASTNode& input_expression, FlattenList& flatten_list) noexcept
  {
    constexpr const ASTNodeDataType& expected_input_data_type = ast_node_data_type_from<Arg>;
    const ASTNodeDataType& input_data_type                    = input_expression.m_data_type;

    constexpr size_t i_argument = sizeof...(InputType) - 1 - sizeof...(RemainingArgs);
    flatten_list[i_argument]    = false;

    if (not isNaturalConversion(expected_input_data_type, input_data_type)) {
      if ((expected_input_data_type == ASTNodeDataType::vector_t) and (input_data_type == ASTNodeDataType::list_t)) {
        flatten_list[i_argument] = true;
        if (expected_input_data_type.dimension() != input_expression.children.size()) {
          return false;
        } else {
          for (const auto& child : input_expression.children) {
            const ASTNodeDataType& data_type = child->m_data_type;
            if (not isNaturalConversion(ast_node_data_type_from<double>, data_type)) {
              return false;
            }
          }
        }
      } else {
        return false;
      }
    }
    if constexpr (sizeof...(RemainingArgs) == 0) {
      return true;
    } else {
      // NOTE(review): functions taking several C++ arguments are not matched
      // against the language function's domain yet; always reported invalid.
      return false;
    }
  }

  // Validates the function's parameter expression against InputType...
  // and fills flatten_list accordingly.
  [[nodiscard]] PUGS_INLINE static bool
  _checkValidInputDataType(const ASTNode& input_expression, FlattenList& flatten_list) noexcept
  {
    return _checkValidArgumentDataType<InputType...>(input_expression, flatten_list);
  }

  // Accepts exactly the return types getResultConverter knows how to turn
  // into OutputType: a natural conversion, a list of the right size whose
  // components convert to double, or a double promoted to R^1.
  [[nodiscard]] PUGS_INLINE static bool
  _checkValidOutputDataType(const ASTNode& return_expression) noexcept
  {
    constexpr const ASTNodeDataType& expected_return_data_type = ast_node_data_type_from<OutputType>;
    const ASTNodeDataType& return_data_type                    = return_expression.m_data_type;

    if (isNaturalConversion(return_data_type, expected_return_data_type)) {
      return true;
    }

    if (expected_return_data_type == ASTNodeDataType::vector_t) {
      if (return_data_type == ASTNodeDataType::list_t) {
        if (expected_return_data_type.dimension() != return_expression.children.size()) {
          return false;
        }
        for (const auto& child : return_expression.children) {
          if (not isNaturalConversion(child->m_data_type, ast_node_data_type_from<double>)) {
            return false;
          }
        }
        return true;
      }
      // a scalar result is handled by getResultConverter only for R^1
      if ((expected_return_data_type.dimension() == 1) and (return_data_type == ASTNodeDataType::double_t)) {
        return true;
      }
    }

    // any other mismatch (e.g. R^2 returned where R^3 is expected) is an
    // invalid function type, not something to discover at evaluation time
    return false;
  }

  // Builds the product type name of Arg, RemainingArgs... (e.g. "R*R").
  template <typename Arg, typename... RemainingArgs>
  [[nodiscard]] PUGS_INLINE static std::string
  _getCompoundTypeName()
  {
    if constexpr (sizeof...(RemainingArgs) > 0) {
      return dataTypeName(ast_node_data_type_from<Arg>) + "*" + _getCompoundTypeName<RemainingArgs...>();
    } else {
      return dataTypeName(ast_node_data_type_from<Arg>);
    }
  }

  // Human-readable name of the expected input type, used in error messages.
  [[nodiscard]] static std::string
  _getInputDataTypeName()
  {
    return _getCompoundTypeName<InputType...>();
  }

 protected:
  // Checks that the language function identified by function_symbol_id
  // matches OutputType(InputType...) and returns the per-argument flatten
  // flags. Throws RuntimeError (with expected/provided types) otherwise.
  [[nodiscard]] PUGS_INLINE static FlattenList
  getFlattenArgs(const FunctionSymbolId& function_symbol_id)
  {
    auto& function = function_symbol_id.symbolTable().functionTable()[function_symbol_id.id()];

    // value-initialized: entries not visited by the checks stay 0 (no flatten)
    FlattenList flatten_list{};

    bool has_valid_input  = _checkValidInputDataType(*function.definitionNode().children[0], flatten_list);
    bool has_valid_output = _checkValidOutputDataType(*function.definitionNode().children[1]);

    if (not(has_valid_input and has_valid_output)) {
      std::ostringstream error_message;
      // no std::ends here: it would embed a '\0' inside the std::string
      error_message << "invalid function type" << rang::style::reset << "\nnote: expecting " << rang::fgB::yellow
                    << _getInputDataTypeName() << " -> " << dataTypeName(ast_node_data_type_from<OutputType>)
                    << rang::style::reset << '\n'
                    << "note: provided function " << rang::fgB::magenta << function.name() << ": "
                    << function.domainMappingNode().string() << rang::style::reset;
      throw RuntimeError(error_message.str());
    }

    return flatten_list;
  }

  // Returns the body (second child of the definition node) of the function.
  [[nodiscard]] PUGS_INLINE static auto&
  getFunctionExpression(const FunctionSymbolId& function_symbol_id)
  {
    auto& function = function_symbol_id.symbolTable().functionTable()[function_symbol_id.id()];

    return *function.definitionNode().children[1];
  }

  // Builds one ExecutionPolicy per Kokkos thread-pool slot, each owning its
  // own freshly allocated Values storage sized like the expression's
  // symbol-table context.
  [[nodiscard]] PUGS_INLINE static auto
  getContextList(const ASTNode& expression)
  {
    Array<ExecutionPolicy> context_list(Kokkos::DefaultExecutionSpace::impl_thread_pool_size());
    auto& context = expression.m_symbol_table->context();

    for (size_t i = 0; i < context_list.size(); ++i) {
      context_list[i] =
        ExecutionPolicy(ExecutionPolicy{},
                        {context.id(), std::make_shared<ExecutionPolicy::Context::Values>(context.size())});
    }

    return context_list;
  }

  // Stores args into context, flattening the arguments marked in flatten.
  // Args must be exactly InputType... (checked at compile time).
  template <typename... Args>
  PUGS_INLINE static void
  convertArgs(ExecutionPolicy::Context& context, const FlattenList& flatten, const Args&... args)
  {
    static_assert(std::is_same_v<std::tuple<InputType...>, std::tuple<Args...>>, "unexpected input type");
    _convertArgs(context, flatten, 0, args...);
  }

  // Returns a functor turning the evaluated function result (of language
  // type data_type) into OutputType. Throws UnexpectedError for data types
  // that cannot be converted.
  [[nodiscard]] PUGS_INLINE static std::function<OutputType(DataVariant&& result)>
  getResultConverter(const ASTNodeDataType& data_type)
  {
    switch (data_type) {
    case ASTNodeDataType::list_t: {
      return [](DataVariant&& result) -> OutputType {
        AggregateDataVariant& v = std::get<AggregateDataVariant>(result);
        OutputType x;
        for (size_t i = 0; i < x.dimension(); ++i) {
          // validation accepts any component naturally convertible to double
          // (not only double itself), so convert whatever scalar is stored
          x[i] = std::visit(
            [](const auto& value) -> double {
              using ValueT = std::decay_t<decltype(value)>;
              if constexpr (std::is_arithmetic_v<ValueT>) {
                return static_cast<double>(value);
              } else {
                throw UnexpectedError("unexpected list component type");
              }
            },
            v[i]);
        }
        return x;
      };
    }
    case ASTNodeDataType::vector_t: {
      return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
    }
    case ASTNodeDataType::double_t: {
      if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
        return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
      } else {
        throw UnexpectedError("unexpected data_type");
      }
    }
    default: {
      throw UnexpectedError("unexpected data_type");
    }
    }
  }
};

#endif   // PUGS_FUNCTION_ADAPTER_HPP
