#ifndef PUGS_FUNCTION_ADAPTER_HPP
#define PUGS_FUNCTION_ADAPTER_HPP

#include <language/ast/ASTNode.hpp>
#include <language/node_processor/ExecutionPolicy.hpp>
#include <language/utils/ASTNodeDataType.hpp>
#include <language/utils/ASTNodeDataTypeTraits.hpp>
#include <language/utils/SymbolTable.hpp>
#include <utils/Exceptions.hpp>
#include <utils/PugsMacros.hpp>
#include <utils/SmallArray.hpp>

#include <Kokkos_Core.hpp>

#include <array>
#include <cstddef>
#include <functional>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>

template <typename T>
class PugsFunctionAdapter;
template <typename OutputType, typename... InputType>
class PugsFunctionAdapter<OutputType(InputType...)>
{
 protected:
  using InputTuple              = std::tuple<std::decay_t<InputType>...>;
  constexpr static size_t NArgs = std::tuple_size_v<InputTuple>;

 private:
  template <typename T, typename... Args>
  PUGS_INLINE static void
  _convertArgs(ExecutionPolicy::Context& context, size_t i_context, const T& t, Args&&... args)
  {
    context[i_context++] = t;
    if constexpr (sizeof...(Args) > 0) {
      _convertArgs(context, i_context, std::forward<Args>(args)...);
    }
  }

  template <size_t I>
  [[nodiscard]] PUGS_INLINE static bool
  _checkValidArgumentDataType(const ASTNode& arg_expression) noexcept(NO_ASSERT)
  {
    using Arg = std::tuple_element_t<I, InputTuple>;

    constexpr const ASTNodeDataType& expected_input_data_type = ast_node_data_type_from<Arg>;

    Assert(arg_expression.m_data_type == ASTNodeDataType::typename_t);
    const ASTNodeDataType& arg_data_type = arg_expression.m_data_type.contentType();

    return isNaturalConversion(expected_input_data_type, arg_data_type);
  }

  template <size_t... I>
  [[nodiscard]] PUGS_INLINE static bool
  _checkAllInputDataType(const ASTNode& input_expression, std::index_sequence<I...>)
  {
    Assert(NArgs == input_expression.children.size());
    return (_checkValidArgumentDataType<I>(*input_expression.children[I]) and ...);
  }

  [[nodiscard]] PUGS_INLINE static bool
  _checkValidInputDomain(const ASTNode& input_domain_expression) noexcept
  {
    if constexpr (NArgs == 1) {
      return _checkValidArgumentDataType<0>(input_domain_expression);
    } else {
      if ((input_domain_expression.m_data_type.contentType() != ASTNodeDataType::list_t) or
          (input_domain_expression.children.size() != NArgs)) {
        return false;
      }

      using IndexSequence = std::make_index_sequence<NArgs>;
      return _checkAllInputDataType(input_domain_expression, IndexSequence{});
    }
  }

  [[nodiscard]] PUGS_INLINE static bool
  _checkValidOutputDomain(const ASTNode& output_domain_expression) noexcept(NO_ASSERT)
  {
    constexpr const ASTNodeDataType& expected_return_data_type = ast_node_data_type_from<OutputType>;
    const ASTNodeDataType& return_data_type                    = output_domain_expression.m_data_type.contentType();

    return isNaturalConversion(return_data_type, expected_return_data_type);
  }

  template <typename Arg, typename... RemainingArgs>
  [[nodiscard]] PUGS_INLINE static std::string
  _getCompoundTypeName()
  {
    if constexpr (sizeof...(RemainingArgs) > 0) {
      return dataTypeName(ast_node_data_type_from<Arg>) + '*' + _getCompoundTypeName<RemainingArgs...>();
    } else {
      return dataTypeName(ast_node_data_type_from<Arg>);
    }
  }

  [[nodiscard]] static std::string
  _getInputDataTypeName()
  {
    return _getCompoundTypeName<InputType...>();
  }

  PUGS_INLINE static void
  _checkFunction(const FunctionDescriptor& function)
  {
    bool has_valid_input_domain = _checkValidInputDomain(*function.domainMappingNode().children[0]);
    bool has_valid_output       = _checkValidOutputDomain(*function.domainMappingNode().children[1]);

    if (not(has_valid_input_domain and has_valid_output)) {
      std::ostringstream error_message;
      error_message << "invalid function type" << rang::style::reset << "\nnote: expecting " << rang::fgB::yellow
                    << _getInputDataTypeName() << " -> " << dataTypeName(ast_node_data_type_from<OutputType>)
                    << rang::style::reset << '\n'
                    << "note: provided function " << rang::fgB::magenta << function.name() << ": "
                    << function.domainMappingNode().string() << rang::style::reset;
      throw NormalError(error_message.str());
    }
  }

 protected:
  [[nodiscard]] PUGS_INLINE static auto&
  getFunctionExpression(const FunctionSymbolId& function_symbol_id)
  {
    auto& function_descriptor = function_symbol_id.descriptor();
    _checkFunction(function_descriptor);

    return *function_descriptor.definitionNode().children[1];
  }

  [[nodiscard]] PUGS_INLINE static auto
  getContextList(const ASTNode& expression)
  {
    SmallArray<ExecutionPolicy> context_list(Kokkos::DefaultExecutionSpace::impl_thread_pool_size());
    auto& context = expression.m_symbol_table->context();

    for (size_t i = 0; i < context_list.size(); ++i) {
      context_list[i] =
        ExecutionPolicy(ExecutionPolicy{},
                        {context.id(), std::make_shared<ExecutionPolicy::Context::Values>(context.size())});
    }

    return context_list;
  }

  template <typename... Args>
  PUGS_INLINE static void
  convertArgs(ExecutionPolicy::Context& context, Args&&... args)
  {
    static_assert(std::is_same_v<std::tuple<std::decay_t<InputType>...>, std::tuple<std::decay_t<Args>...>>,
                  "unexpected input type");
    _convertArgs(context, 0, args...);
  }

  [[nodiscard]] PUGS_INLINE static std::function<OutputType(DataVariant&& result)>
  getResultConverter(const ASTNodeDataType& data_type)
  {
    if constexpr (is_tiny_vector_v<OutputType>) {
      switch (data_type) {
      case ASTNodeDataType::vector_t: {
        return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
      }
      case ASTNodeDataType::bool_t: {
        if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
          return
            [](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<bool>(result))}; };
        } else {
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
          // LCOV_EXCL_STOP
        }
      }
      case ASTNodeDataType::unsigned_int_t: {
        if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
          return [](DataVariant&& result) -> OutputType {
            return OutputType(static_cast<double>(std::get<uint64_t>(result)));
          };
        } else {
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
          // LCOV_EXCL_STOP
        }
      }
      case ASTNodeDataType::int_t: {
        if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
          return [](DataVariant&& result) -> OutputType {
            return OutputType{static_cast<double>(std::get<int64_t>(result))};
          };
        } else {
          // If this point is reached must be a 0 vector
          return [](DataVariant &&) -> OutputType { return OutputType{ZeroType{}}; };
        }
      }
      case ASTNodeDataType::double_t: {
        if constexpr (std::is_same_v<OutputType, TinyVector<1>>) {
          return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
        } else {
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
          // LCOV_EXCL_STOP
        }
      }
        // LCOV_EXCL_START
      default: {
        throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                              dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
      }
        // LCOV_EXCL_STOP
      }
    } else if constexpr (is_tiny_matrix_v<OutputType>) {
      switch (data_type) {
      case ASTNodeDataType::matrix_t: {
        return [](DataVariant&& result) -> OutputType { return std::get<OutputType>(result); };
      }
      case ASTNodeDataType::bool_t: {
        if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) {
          return
            [](DataVariant&& result) -> OutputType { return OutputType{static_cast<double>(std::get<bool>(result))}; };
        } else {
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
          // LCOV_EXCL_STOP
        }
      }
      case ASTNodeDataType::unsigned_int_t: {
        if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) {
          return [](DataVariant&& result) -> OutputType {
            return OutputType(static_cast<double>(std::get<uint64_t>(result)));
          };
        } else {
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
          // LCOV_EXCL_STOP
        }
      }
      case ASTNodeDataType::int_t: {
        if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) {
          return [](DataVariant&& result) -> OutputType {
            return OutputType{static_cast<double>(std::get<int64_t>(result))};
          };
        } else {
          // If this point is reached must be a 0 matrix
          return [](DataVariant &&) -> OutputType { return OutputType{ZeroType{}}; };
        }
      }
      case ASTNodeDataType::double_t: {
        if constexpr (std::is_same_v<OutputType, TinyMatrix<1>>) {
          return [](DataVariant&& result) -> OutputType { return OutputType{std::get<double>(result)}; };
        } else {
          // LCOV_EXCL_START
          throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                                dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
          // LCOV_EXCL_STOP
        }
      }
        // LCOV_EXCL_START
      default: {
        throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                              dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
      }
        // LCOV_EXCL_STOP
      }
    } else if constexpr (std::is_arithmetic_v<OutputType>) {
      switch (data_type) {
      case ASTNodeDataType::bool_t: {
        return [](DataVariant&& result) -> OutputType { return std::get<bool>(result); };
      }
      case ASTNodeDataType::unsigned_int_t: {
        return [](DataVariant&& result) -> OutputType { return std::get<uint64_t>(result); };
      }
      case ASTNodeDataType::int_t: {
        return [](DataVariant&& result) -> OutputType { return std::get<int64_t>(result); };
      }
      case ASTNodeDataType::double_t: {
        return [](DataVariant&& result) -> OutputType { return std::get<double>(result); };
      }
        // LCOV_EXCL_START
      default: {
        throw UnexpectedError("unexpected data_type, cannot convert \"" + dataTypeName(data_type) + "\" to \"" +
                              dataTypeName(ast_node_data_type_from<OutputType>) + "\"");
      }
        // LCOV_EXCL_STOP
      }
    } else {
      static_assert(std::is_arithmetic_v<OutputType>, "unexpected output type");
    }
  }

  [[nodiscard]] PUGS_INLINE static std::function<std::vector<OutputType>(DataVariant&& result)>
  getArrayResultConverter(const ASTNodeDataType& data_type)
  {
    Assert(data_type == ASTNodeDataType::list_t);

    if constexpr (std::is_arithmetic_v<OutputType>) {
      return [&](DataVariant&& result) -> std::vector<OutputType> {
        return std::visit(
          [&](auto&& value) -> std::vector<OutputType> {
            using ValueType = std::decay_t<decltype(value)>;
            if constexpr (std::is_same_v<ValueType, AggregateDataVariant>) {
              std::vector<OutputType> array(value.size());

              for (size_t i = 0; i < value.size(); ++i) {
                array[i] = std::visit(
                  [&](auto&& value_i) -> OutputType {
                    using Value_I_Type = std::decay_t<decltype(value_i)>;
                    if constexpr (std::is_arithmetic_v<Value_I_Type>) {
                      return value_i;
                    } else {
                      throw UnexpectedError("expecting arithmetic type");
                    }
                  },
                  value[i]);
              }

              return array;
            } else {
              throw UnexpectedError("invalid DataVariant");
            }
          },
          result);
      };
    } else {
      throw NotImplementedError("non-arithmetic tuple type");
    }
  }

  PugsFunctionAdapter() = delete;
};

#endif   // PUGS_FUNCTION_ADAPTER_HPP