#ifndef BUILTIN_FUNCTION_EMBEDDER_HPP
#define BUILTIN_FUNCTION_EMBEDDER_HPP

#include <language/utils/ASTNodeDataType.hpp>
#include <language/utils/ASTNodeDataTypeTraits.hpp>
#include <language/utils/DataHandler.hpp>
#include <language/utils/DataVariant.hpp>
#include <language/utils/FunctionTable.hpp>
#include <utils/Demangle.hpp>
#include <utils/Exceptions.hpp>
#include <utils/PugsTraits.hpp>

#include <functional>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <variant>
#include <vector>

/// Type-erased interface of a C++ function exposed to the language as a builtin.
///
/// An implementation describes the signature of the embedded function (number
/// and language types of its parameters, language type of its result) and
/// evaluates it on a list of language values.
class IBuiltinFunctionEmbedder
{
 public:
  // Embedders are neither copyable nor movable.
  IBuiltinFunctionEmbedder()                                = default;
  IBuiltinFunctionEmbedder(const IBuiltinFunctionEmbedder&) = delete;
  IBuiltinFunctionEmbedder(IBuiltinFunctionEmbedder&&)      = delete;

  /// @return the number of arguments expected by apply()
  virtual size_t numberOfParameters() const = 0;

  /// @return the language type of the value produced by apply()
  virtual ASTNodeDataType getReturnDataType() const = 0;

  /// @return the language type of each parameter, first parameter first
  virtual std::vector<ASTNodeDataType> getParameterDataTypes() const = 0;

  /// Evaluates the embedded function on the given argument values.
  virtual DataVariant apply(const std::vector<DataVariant>&) const = 0;

  virtual ~IBuiltinFunctionEmbedder() = default;
};

template <typename FX, typename... Args>
class BuiltinFunctionEmbedderBase;

/// Part of BuiltinFunctionEmbedder<FX(Args...)> that depends only on the C++
/// return type FX.
///
/// The constructor validates FX, or each element of FX when FX is a std::tuple.
/// At compile time, a returned reference must be const, and a pointer-like type
/// detected by is_std_ptr_v must point to const. At construction, the type must
/// have an associated language type, otherwise std::invalid_argument is thrown.
template <typename FX, typename... Args>
class BuiltinFunctionEmbedderBase<FX(Args...)> : public IBuiltinFunctionEmbedder
{
 protected:
  /// Compile-time mutability check of a returned type ValueT.
  template <typename ValueT>
  PUGS_INLINE void constexpr _check_value_type() const
  {
    if constexpr (std::is_lvalue_reference_v<ValueT>) {
      static_assert(std::is_const_v<std::remove_reference_t<ValueT>>,
                    "builtin function return values are non mutable use 'const' when passing references");
    }

    if constexpr (is_std_ptr_v<ValueT>) {
      static_assert(std::is_const_v<typename ValueT::element_type>,
                    "builtin function return values are non mutable. For instance use std::shared_ptr<const T>");
    }
  }

  /// Checks the I-th element type of the std::tuple return type FX.
  /// @throws std::invalid_argument if that element has no associated language type
  template <size_t I>
  PUGS_INLINE void constexpr _check_value() const
  {
    using ValueN_T = std::tuple_element_t<I, FX>;
    _check_value_type<ValueN_T>();

    if (ast_node_data_type_from<std::remove_cv_t<std::remove_reference_t<ValueN_T>>> == ASTNodeDataType::undefined_t) {
      throw std::invalid_argument(std::string{"cannot bind C++ to language.\nnote: return value number "} +
                                  std::to_string(I + 1) + std::string{" has no associated language type: "} +
                                  demangle<ValueN_T>());
    }
  }

  /// Checks every element type of the std::tuple return type FX.
  template <size_t... I>
  PUGS_INLINE void constexpr _check_tuple_value(std::index_sequence<I...>) const
  {
    (_check_value<I>(), ...);
  }

  /// Checks the return type FX; called by the constructor.
  /// @throws std::invalid_argument if FX (or one of its tuple elements) has no
  ///         associated language type
  PUGS_INLINE void constexpr _check_return_type() const
  {
    if constexpr (is_std_tuple_v<FX>) {
      constexpr size_t N  = std::tuple_size_v<FX>;
      using IndexSequence = std::make_index_sequence<N>;
      _check_tuple_value(IndexSequence{});
    } else {
      if (ast_node_data_type_from<std::remove_cv_t<std::remove_reference_t<FX>>> == ASTNodeDataType::undefined_t) {
        throw std::invalid_argument(
          std::string{"cannot bind C++ to language.\nnote: return value has no associated language type: "} +
          demangle<FX>());
      }
      _check_value_type<FX>();
    }
  }

  /// @return the language type associated with the C++ type T
  template <typename T>
  PUGS_INLINE ASTNodeDataType
  _getDataType() const
  {
    Assert(ast_node_data_type_from<T> != ASTNodeDataType::undefined_t);
    return ast_node_data_type_from<T>;
  }

  /// @return the language type of the (decayed) I-th element type of TupleT
  template <typename TupleT, size_t I>
  PUGS_INLINE ASTNodeDataType
  _getOneElementDataType() const
  {
    // std::tuple_element_t names the element type without forming a TupleT
    // value: decltype(std::get<I>(TupleT{})) required TupleT to be default
    // constructible, which fails e.g. for tuples holding const references.
    using ArgN_T = std::decay_t<std::tuple_element_t<I, TupleT>>;
    return this->template _getDataType<ArgN_T>();
  }

  /// @return the language types of the elements of the std::tuple return type FX
  template <size_t... I>
  PUGS_INLINE std::vector<std::shared_ptr<const ASTNodeDataType>> _getCompoundDataTypes(std::index_sequence<I...>) const
  {
    std::vector<std::shared_ptr<const ASTNodeDataType>> compound_type_list;
    compound_type_list.reserve(sizeof...(I));
    (compound_type_list.push_back(std::make_shared<ASTNodeDataType>(this->template _getOneElementDataType<FX, I>())),
     ...);
    return compound_type_list;
  }

  /// Wraps a C++ object into an EmbeddedData (through a DataHandler<T>).
  template <typename T>
  PUGS_INLINE EmbeddedData
  _createHandler(std::shared_ptr<T> data) const
  {
    return EmbeddedData{std::make_shared<DataHandler<T>>(std::move(data))};
  }

  /// Wraps each C++ object of a list into an EmbeddedData (through DataHandler<T>).
  template <typename T>
  PUGS_INLINE std::vector<EmbeddedData>
  _createHandler(std::vector<std::shared_ptr<T>> data) const
  {
    std::vector<EmbeddedData> embedded;
    embedded.reserve(data.size());
    for (auto& data_i : data) {
      embedded.emplace_back(std::make_shared<DataHandler<T>>(std::move(data_i)));
    }
    return embedded;
  }

  /// Converts one returned value into a DataVariant.
  ///
  /// Values that are already DataVariant alternatives are moved in as is; other
  /// values are wrapped by _createHandler. The argument is moved from even when
  /// bound to an lvalue: callers pass elements of a local result tuple that is
  /// not used afterwards.
  template <typename ResultT>
  PUGS_INLINE DataVariant
  _resultToDataVariant(ResultT&& result) const
  {
    if constexpr (is_data_variant_v<std::decay_t<ResultT>>) {
      return std::move(result);
    } else {
      return _createHandler(std::move(result));
    }
  }

 public:
  /// @return the language type of FX; a list type built from the element types
  ///         when FX is a std::tuple
  PUGS_INLINE ASTNodeDataType
  getReturnDataType() const final
  {
    if constexpr (is_std_tuple_v<FX>) {
      constexpr size_t N  = std::tuple_size_v<FX>;
      using IndexSequence = std::make_index_sequence<N>;
      return ASTNodeDataType::build<ASTNodeDataType::list_t>(this->_getCompoundDataTypes(IndexSequence{}));
    } else {
      return this->_getDataType<FX>();
    }
  }

  /// @throws std::invalid_argument if the return type cannot be bound to the language
  BuiltinFunctionEmbedderBase()
  {
    this->_check_return_type();
  }

  BuiltinFunctionEmbedderBase(const BuiltinFunctionEmbedderBase&) = delete;
  BuiltinFunctionEmbedderBase(BuiltinFunctionEmbedderBase&&)      = delete;

  virtual ~BuiltinFunctionEmbedderBase() = default;
};

/// Primary template. Only the function-type form BuiltinFunctionEmbedder<FX(Args...)>
/// (a partial specialization) is usable; instantiating this form, for instance
/// BuiltinFunctionEmbedder<double, double>, is a compile-time error.
template <typename FX, typename... Args>
class BuiltinFunctionEmbedder
{
  // A class type is never void, so this condition is always false. It depends on
  // FX and Args, so it is only evaluated when this template is instantiated.
  // (std::is_class_v<...> was always true, even for the incomplete class being
  // defined, so that assertion could never fire.)
  static_assert(std::is_void_v<BuiltinFunctionEmbedder<FX, Args...>>,
                "wrong template parameters do not use <FX, Args...>, use <FX(Args...)>");
};

/// True when T is not an lvalue reference, or is an lvalue reference to a
/// const-qualified type: `double`, `const double&` and `double&&` qualify,
/// `double&` does not.
///
/// Constness must be tested on the referred type: std::is_const_v<const X&> is
/// false because references themselves are never cv-qualified.
template <typename T>
inline constexpr bool is_const_ref_or_non_ref =
  (std::is_lvalue_reference_v<T> and std::is_const_v<std::remove_reference_t<T>>) or
  (not std::is_lvalue_reference_v<T>);

/// Embeds a C++ function FX(Args...) so that it can be called from the language.
///
/// The return type is validated by BuiltinFunctionEmbedderBase. The constructor
/// validates each argument type and throws std::invalid_argument when one has
/// no associated language type. apply() copies the language values into a tuple
/// of C++ arguments, calls the function and converts its result into a
/// DataVariant.
template <typename FX, typename... Args>
class BuiltinFunctionEmbedder<FX(Args...)> : public BuiltinFunctionEmbedderBase<FX(Args...)>
{
 private:
  std::function<FX(Args...)> m_f;
  // Arguments are held by value (reference and cv-qualifiers dropped) until the call.
  using ArgsTuple = std::tuple<std::decay_t<Args>...>;

  /// Checks the I-th declared argument type.
  ///
  /// At compile time, references must be const, and pointer-like types detected
  /// by is_std_ptr_v must point to const.
  /// @throws std::invalid_argument if the argument has no associated language type
  template <size_t I>
  PUGS_INLINE void constexpr _check_arg() const
  {
    using ArgN_T = std::tuple_element_t<I, std::tuple<Args...>>;
    if constexpr (std::is_lvalue_reference_v<ArgN_T>) {
      static_assert(std::is_const_v<std::remove_reference_t<ArgN_T>>,
                    "builtin function arguments are non mutable use 'const' when passing references");
    }

    // NOTE(review): is_std_ptr_v is applied to the declared type; a pointer taken
    // by reference (e.g. const std::shared_ptr<T>&) is presumably not checked
    // here. Confirm against is_std_ptr_v's definition.
    if constexpr (is_std_ptr_v<ArgN_T>) {
      static_assert(std::is_const_v<typename ArgN_T::element_type>,
                    "builtin function arguments are non mutable. For instance use std::shared_ptr<const T>");
    }

    if (ast_node_data_type_from<std::remove_cv_t<std::remove_reference_t<ArgN_T>>> == ASTNodeDataType::undefined_t) {
      throw std::invalid_argument(std::string{"cannot bind C++ to language.\nnote: argument number "} +
                                  std::to_string(I + 1) + std::string{" has no associated language type: "} +
                                  demangle<ArgN_T>());
    }
  }

  /// Checks every declared argument type.
  template <size_t... I>
  PUGS_INLINE void constexpr _check_arg_list(std::index_sequence<I...>) const
  {
    (_check_arg<I>(), ...);
  }

  /// Converts the language value v[I] into the I-th C++ argument stored in t.
  ///
  /// Accepted conversions:
  /// - identical types;
  /// - arithmetic to arithmetic;
  /// - arithmetic to std::string (as text);
  /// - EmbeddedData to std::shared_ptr;
  /// - std::vector<EmbeddedData> to a std::vector of std::shared_ptr.
  /// @throws UnexpectedError for any other combination, or when an EmbeddedData
  ///         does not hold the expected DataHandler type
  template <size_t I>
  PUGS_INLINE void
  _copyValue(ArgsTuple& t, const std::vector<DataVariant>& v) const
  {
    std::visit(
      [&](auto&& v_i) {
        using Ti_Type = std::decay_t<decltype(std::get<I>(t))>;
        using Vi_Type = std::decay_t<decltype(v_i)>;

        if constexpr (std::is_same_v<Vi_Type, Ti_Type>) {
          std::get<I>(t) = v_i;
        } else if constexpr (std::is_arithmetic_v<Vi_Type> and std::is_arithmetic_v<Ti_Type>) {
          std::get<I>(t) = v_i;
        } else if constexpr (std::is_arithmetic_v<Vi_Type> and std::is_same_v<Ti_Type, std::string>) {
          // Assigning a number to a std::string selects operator=(char) and keeps a
          // single character (65 becomes "A"); format the value as text instead.
          std::ostringstream os;
          os << std::boolalpha << v_i;
          std::get<I>(t) = os.str();
        } else if constexpr (is_shared_ptr_v<Ti_Type>) {
          if constexpr (std::is_same_v<Vi_Type, EmbeddedData>) {
            using Ti_handled_type = typename Ti_Type::element_type;
            try {
              auto& data_handler = dynamic_cast<const DataHandler<Ti_handled_type>&>(v_i.get());
              std::get<I>(t)     = data_handler.data_ptr();
            }
            catch (const std::bad_cast&) {
              throw UnexpectedError("unexpected argument types while casting: invalid EmbeddedData type, expecting " +
                                    demangle<DataHandler<Ti_handled_type>>());
            }
          } else {
            throw UnexpectedError("unexpected argument types while casting: expecting EmbeddedData");
          }
        } else if constexpr (std::is_same_v<Vi_Type, std::vector<EmbeddedData>> and is_std_vector_v<Ti_Type>) {
          using Ti_value_type = typename Ti_Type::value_type;
          if constexpr (is_shared_ptr_v<Ti_value_type>) {
            using Ti_handled_type = typename Ti_value_type::element_type;
            std::get<I>(t).resize(v_i.size());
            for (size_t j = 0; j < v_i.size(); ++j) {
              try {
                auto& data_handler = dynamic_cast<const DataHandler<Ti_handled_type>&>(v_i[j].get());
                std::get<I>(t)[j]  = data_handler.data_ptr();
              }
              catch (const std::bad_cast&) {
                throw UnexpectedError(
                  "unexpected argument types while casting: invalid EmbeddedData type, expecting " +
                  demangle<DataHandler<Ti_handled_type>>());
              }
            }
          } else {
            throw UnexpectedError("unexpected argument types while casting \"" + demangle<Vi_Type>() + "\" to \"" +
                                  demangle<Ti_Type>() + '"');
          }
        } else {
          throw UnexpectedError("unexpected argument types while casting \"" + demangle<Vi_Type>() + "\" to \"" +
                                demangle<Ti_Type>() + '"');
        }
      },
      v[I]);
  }

  /// Converts every language value of v into the corresponding C++ argument of t.
  template <size_t... I>
  PUGS_INLINE void
  _copyFromVector(ArgsTuple& t, const std::vector<DataVariant>& v, std::index_sequence<I...>) const
  {
    Assert(sizeof...(Args) == v.size());
    (_copyValue<I>(t, v), ...);
  }

  /// @return the language type of each (decayed) argument type
  template <size_t... I>
  PUGS_INLINE std::vector<ASTNodeDataType> _getParameterDataTypes(std::index_sequence<I...>) const
  {
    std::vector<ASTNodeDataType> parameter_type_list;
    parameter_type_list.reserve(sizeof...(I));
    (parameter_type_list.push_back(this->template _getOneElementDataType<ArgsTuple, I>()), ...);
    return parameter_type_list;
  }

  /// Calls the function (whose result is a std::tuple) and converts each
  /// returned element into a DataVariant.
  PUGS_INLINE
  AggregateDataVariant
  _applyToAggregate(const ArgsTuple& t) const
  {
    auto tuple_result = std::apply(m_f, t);
    std::vector<DataVariant> vector_result;
    vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>);

    // `this->` is needed because the base class is dependent; the `template`
    // keyword is not allowed here since no template argument list follows.
    std::apply([&](auto&&... result) { (vector_result.emplace_back(this->_resultToDataVariant(result)), ...); },
               tuple_result);

    return AggregateDataVariant{std::move(vector_result)};
  }

 public:
  PUGS_INLINE std::vector<ASTNodeDataType>
  getParameterDataTypes() const final
  {
    constexpr size_t N  = std::tuple_size_v<ArgsTuple>;
    using IndexSequence = std::make_index_sequence<N>;

    return this->_getParameterDataTypes(IndexSequence{});
  }

  PUGS_INLINE size_t
  numberOfParameters() const final
  {
    return sizeof...(Args);
  }

  /// Converts x into C++ arguments, calls the function and converts its result.
  /// @throws UnexpectedError if an argument cannot be converted
  PUGS_INLINE
  DataVariant
  apply(const std::vector<DataVariant>& x) const final
  {
    constexpr size_t N = std::tuple_size_v<ArgsTuple>;
    ArgsTuple t;
    using IndexSequence = std::make_index_sequence<N>;

    this->_copyFromVector(t, x, IndexSequence{});
    if constexpr (is_data_variant_v<FX>) {
      return {std::apply(m_f, t)};
    } else if constexpr (is_std_tuple_v<FX>) {
      return this->_applyToAggregate(t);
    } else if constexpr (std::is_same_v<FX, void>) {
      std::apply(m_f, t);
      return {};
    } else {
      return this->_createHandler(std::apply(m_f, t));
    }
  }

  /// @throws std::invalid_argument if the return type or an argument type
  ///         cannot be bound to the language
  BuiltinFunctionEmbedder(std::function<FX(Args...)> f) : m_f{std::move(f)}
  {
    using IndexSequence = std::make_index_sequence<std::tuple_size_v<ArgsTuple>>;
    this->_check_arg_list(IndexSequence{});
  }

  BuiltinFunctionEmbedder(const BuiltinFunctionEmbedder&) = delete;
  BuiltinFunctionEmbedder(BuiltinFunctionEmbedder&&)      = delete;

  ~BuiltinFunctionEmbedder() = default;
};

/// Rejects the misspelled form BuiltinFunctionEmbedder<FX, void> at compile
/// time; a function without arguments is written BuiltinFunctionEmbedder<FX(void)>.
template <typename FX>
class BuiltinFunctionEmbedder<FX, void>
{
  // A class type is never void, so this condition is always false. It depends on
  // FX, so it is only evaluated when this specialization is instantiated.
  // (std::is_class_v<...> was always true, so that assertion could never fire.)
  static_assert(std::is_void_v<BuiltinFunctionEmbedder<FX, void>>,
                "wrong template parameters do not use <FX, void>, use <FX(void)>");
};

/// Embeds a C++ function without arguments, FX(void) (the same type as FX()),
/// so that it can be called from the language. The return type is validated by
/// BuiltinFunctionEmbedderBase.
template <typename FX>
class BuiltinFunctionEmbedder<FX(void)> : public BuiltinFunctionEmbedderBase<FX(void)>
{
 private:
  std::function<FX(void)> m_f;

  /// Calls the function (whose result is a std::tuple) and converts each
  /// returned element into a DataVariant.
  PUGS_INLINE
  AggregateDataVariant
  _applyToAggregate() const
  {
    auto tuple_result = m_f();
    std::vector<DataVariant> vector_result;
    vector_result.reserve(std::tuple_size_v<decltype(tuple_result)>);

    // `this->` is needed because the base class is dependent; the `template`
    // keyword is not allowed here since no template argument list follows.
    std::apply([&](auto&&... result) { (vector_result.emplace_back(this->_resultToDataVariant(result)), ...); },
               tuple_result);

    return AggregateDataVariant{std::move(vector_result)};
  }

 public:
  PUGS_INLINE std::vector<ASTNodeDataType>
  getParameterDataTypes() const final
  {
    return {};
  }

  PUGS_INLINE size_t
  numberOfParameters() const final
  {
    return 0;
  }

  /// Calls the function (the argument list is ignored) and converts its result.
  PUGS_INLINE
  DataVariant
  apply(const std::vector<DataVariant>&) const final
  {
    if constexpr (is_data_variant_v<FX>) {
      return {m_f()};
    } else if constexpr (is_std_tuple_v<FX>) {
      return this->_applyToAggregate();
    } else if constexpr (std::is_same_v<FX, void>) {
      m_f();
      return {};
    } else {
      // _createHandler yields an EmbeddedData for a std::shared_ptr result and a
      // std::vector<EmbeddedData> for a vector of them. Both are returned as is,
      // as in the FX(Args...) specialization. The former EmbeddedData(...)
      // wrapper presumably rejected the vector case.
      return this->_createHandler(m_f());
    }
  }

  BuiltinFunctionEmbedder(std::function<FX(void)> f) : m_f{std::move(f)} {}

  BuiltinFunctionEmbedder(const BuiltinFunctionEmbedder&) = delete;
  BuiltinFunctionEmbedder(BuiltinFunctionEmbedder&&)      = delete;

  ~BuiltinFunctionEmbedder() = default;
};

#endif   //  BUILTIN_FUNCTION_EMBEDDER_HPP