#ifndef PUGS_FUNCTION_ADAPTER_HPP
#define PUGS_FUNCTION_ADAPTER_HPP

#include <language/ast/ASTNode.hpp>
#include <language/ast/ASTNodeDataType.hpp>
#include <language/node_processor/ExecutionPolicy.hpp>
#include <language/utils/ASTNodeDataTypeTraits.hpp>
#include <language/utils/SymbolTable.hpp>
#include <utils/Array.hpp>
#include <utils/Exceptions.hpp>
#include <utils/PugsMacros.hpp>

#include <Kokkos_Core.hpp>

#include <array>

template <typename T>
class PugsFunctionAdapter;
template <typename OutputType, typename... InputType>
class PugsFunctionAdapter<OutputType(InputType...)>
{
 protected:
  using InputTuple              = std::tuple<std::decay_t<InputType>...>;
  constexpr static size_t NArgs = std::tuple_size_v<InputTuple>;

 private:
  template <typename T, typename... Args>
  PUGS_INLINE static void
  _convertArgs(ExecutionPolicy::Context& context, size_t i_context, const T& t, Args&&... args)
  {
    context[i_context++] = t;
    if constexpr (sizeof...(Args) > 0) {
      _convertArgs(context, i_context, std::forward<Args>(args)...);
    }
  }

  template <size_t I>
  [[nodiscard]] PUGS_INLINE static bool
  _checkValidArgumentDataType(const ASTNode& arg_expression) noexcept(NO_ASSERT)
  {
    using Arg = std::tuple_element_t<I, InputTuple>;

    constexpr const ASTNodeDataType& expected_input_data_type = ast_node_data_type_from<Arg>;

    Assert(arg_expression.m_data_type == ASTNodeDataType::typename_t);
    const ASTNodeDataType& arg_data_type = arg_expression.m_data_type.contentType();

    return isNaturalConversion(expected_input_data_type, arg_data_type);
  }

  template <size_t... I>
  [[nodiscard]] PUGS_INLINE static bool
  _checkAllInputDataType(const ASTNode& input_expression, std::index_sequence<I...>)
  {
    Assert(NArgs == input_expression.children.size());
    return (_checkValidArgumentDataType<I>(*input_expression.children[I]) and ...);
  }

  [[nodiscard]] PUGS_INLINE static bool
  _checkValidInputDomain(const ASTNode& input_domain_expression) noexcept
  {
    if constexpr (NArgs == 1) {
      return _checkValidArgumentDataType<0>(input_domain_expression);
    } else {
      if ((input_domain_expression.m_data_type.contentType() != ASTNodeDataType::list_t) or
          (input_domain_expression.children.size() != NArgs)) {
        return false;
      }

      using IndexSequence = std::make_index_sequence<NArgs>;
      return _checkAllInputDataType(input_domain_expression, IndexSequence{});
    }
  }

  [[nodiscard]] PUGS_INLINE static bool
  _checkValidOutputDomain(const ASTNode& output_domain_expression) noexcept(NO_ASSERT)
  {
    constexpr const ASTNodeDataType& expected_return_data_type = ast_node_data_type_from<OutputType>;
    const ASTNodeDataType& return_data_type                    = output_domain_expression.m_data_type.contentType();

    return isNaturalConversion(return_data_type, expected_return_data_type);
  }

  template <typename Arg, typename... RemainingArgs>
  [[nodiscard]] PUGS_INLINE static std::string
  _getCompoundTypeName()
  {
    if constexpr (sizeof...(RemainingArgs) > 0) {
      return dataTypeName(ast_node_data_type_from<Arg>) + '*' + _getCompoundTypeName<RemainingArgs...>();
    } else {
      return dataTypeName(ast_node_data_type_from<Arg>);
    }
  }

  [[nodiscard]] static std::string
  _getInputDataTypeName()
  {
    return _getCompoundTypeName<InputType...>();
  }

  PUGS_INLINE static void
  _checkFunction(const FunctionDescriptor& function)
  {
    bool has_valid_input_domain = _checkValidInputDomain(*function.domainMappingNode().children[0]);
    bool has_valid_output       = _checkValidOutputDomain(*function.domainMappingNode().children[1]);

    if (not(has_valid_input_domain and has_valid_output)) {
      std::ostringstream error_message;
      error_message << "invalid function type" << rang::style::reset << "\nnote: expecting " << rang::fgB::yellow
                    << _getInputDataTypeName() << " -> " << dataTypeName(ast_node_data_type_from<OutputType>)
                    << rang::style::reset << '\n'
                    << "note: provided function " << rang::fgB::magenta << function.name() << ": "
                    << function.domainMappingNode().string() << rang::style::reset << std::ends;
      throw NormalError(error_message.str());
    }
  }

 protected:
  [[nodiscard]] PUGS_INLINE static auto&
  getFunctionExpression(const FunctionSymbolId& function_symbol_id)
  {
    auto& function = function_symbol_id.symbolTable().functionTable()[function_symbol_id.id()];
    _checkFunction(function);

    return *function.definitionNode().children[1];
  }

  [[nodiscard]] PUGS_INLINE static auto
  getContextList(const ASTNode& expression)
  {
    Array<ExecutionPolicy> context_list(Kokkos::DefaultExecutionSpace::impl_thread_pool_size());
    auto& context = expression.m_symbol_table->context();

    for (size_t i = 0; i < context_list.size(); ++i) {
      context_list[i] =
        ExecutionPolicy(ExecutionPolicy{},
                        {context.id(), std::make_shared<ExecutionPolicy::Context::Values>(context.size())});
    }

    return context_list;
  }

  template <typename... Args>
  PUGS_INLINE static void
  convertArgs(ExecutionPolicy::Context& context, Args&&... args)
  {
    static_assert(std::is_same_v<std::tuple<std::decay_t<InputType>...>, std::tuple<std::decay_t<Args>...>>,
                  "unexpected input type");
    _convertArgs(context, 0, args...);
  }

  [[nodiscard]] PUGS_INLINE static std::function<OutputType(DataVariant&& result)>
  getResultConverter(const ASTNodeDataType& data_type)
  {
    if constexpr (is_tiny_vector_v<OutputType>) {
      switch (data_type) {
      case ASTNodeDataType::list_t: {
        return [](DataVariant&& result) -> OutputType {
          AggregateDataVariant& v = std::get<AggregateDataVariant>(result);
          OutputType x;

          for (size_t i = 0; i < x.dimension(); ++i) {
            std::visit(
              [&](auto&& vi) {
                using Vi_T = std::decay_t<decltype(vi)>;
                if constexpr (std::is_arithmetic_v<Vi_T>) {
                  x[i] = vi;
                } else {
                  // LCOV_EXCL_START
                  throw UnexpectedError("expecting arithmetic value");
                  // LCOV_EXCL_STOP
                }
              },
              v[i]);
          }
          return x;
        };
      }
      case ASTNodeDataType::vector_t: {
        return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
      }
      case ASTNodeDataType::bool_t: {
        if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
          return
            [](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<bool>(result))}; };
        } else {
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
          // LCOV_EXCL_STOP
        }
      }
      case ASTNodeDataType::unsigned_int_t: {
        if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
          return [](DataVariant&& result) -> OutputType {
            return OutputType(static_cast<double>(std::get<uint64_t>(result)));
          };
        } else {
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
          // LCOV_EXCL_STOP
        }
      }
      case ASTNodeDataType::int_t: {
        if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
          return [](DataVariant&& result) -> OutputType {
            return OutputType{static_cast<double>(std::get<int64_t>(result))};
          };
        } else {
          // If this point is reached must be a 0 vector
          return [](DataVariant &&) -> OutputType { return OutputType{ZeroType{}}; };
        }
      }
      case ASTNodeDataType::double_t: {
        if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
          return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
        } else {
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
          // LCOV_EXCL_STOP
        }
      }
        // LCOV_EXCL_START
      default: {
        throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                              dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
      }
        // LCOV_EXCL_STOP
      }
    } else if constexpr (std::is_arithmetic_v<OutputType>) {
      switch (data_type) {
      case ASTNodeDataType::bool_t: {
        return [](DataVariant&& result) -> OutputType { return std::get<bool>(result); };
      }
      case ASTNodeDataType::unsigned_int_t: {
        return [](DataVariant&& result) -> OutputType { return std::get<uint64_t>(result); };
      }
      case ASTNodeDataType::int_t: {
        return [](DataVariant&& result) -> OutputType { return std::get<int64_t>(result); };
      }
      case ASTNodeDataType::double_t: {
        return [](DataVariant&& result) -> OutputType { return std::get<double>(result); };
      }
        // LCOV_EXCL_START
      default: {
        throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                              dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
      }
        // LCOV_EXCL_STOP
      }
    } else {
      static_assert(std::is_arithmetic_v<OutputType>, "unexpected output type");
    }
  }

  PugsFunctionAdapter() = delete;
};

#endif   // PUGS_FUNCTION_ADAPTER_HPP