#ifndef FUNCTION_PROCESSOR_HPP
#define FUNCTION_PROCESSOR_HPP

#include <language/PEGGrammar.hpp>
#include <language/node_processor/ASTNodeExpressionListProcessor.hpp>
#include <language/node_processor/FunctionArgumentConverter.hpp>
#include <language/node_processor/INodeProcessor.hpp>
#include <language/utils/FunctionTable.hpp>
#include <language/utils/SymbolTable.hpp>
#include <utils/Stringify.hpp>

/**
 * Node processor that evaluates one expression node and converts the
 * resulting value to ReturnType.
 *
 * @tparam ReturnType          type of the value stored in the returned DataVariant
 * @tparam ExpressionValueType type extracted (through std::get) from the value
 *                             produced by the wrapped expression node
 */
template <typename ReturnType, typename ExpressionValueType>
class FunctionExpressionProcessor final : public INodeProcessor
{
 private:
  // Expression node evaluated by each call to execute()
  ASTNode& m_function_expression;

 public:
  /**
   * Evaluates the wrapped expression and converts its value to ReturnType.
   *
   * The conversion is selected at compile time, first match wins:
   *  - same types: the expression value is returned untouched;
   *  - ReturnType is std::string: the value is passed through stringify();
   *  - ExpressionValueType is ZeroType: ReturnType is built from
   *    ZeroType::zero (the expression itself is not evaluated);
   *  - ExpressionValueType is convertible to ReturnType: static_cast, with a
   *    sign check for the int64_t -> uint64_t case;
   *  - arithmetic value to tiny vector/matrix: only allowed when
   *    ReturnType::Dimension is 1 (enforced by static_assert).
   *
   * @param exec_policy execution policy forwarded to the expression node
   * @return the converted value
   *
   * @throws std::domain_error if a negative int64_t has to be converted to
   *         uint64_t
   * @throws std::bad_variant_access (from std::get) if the expression value
   *         does not hold an ExpressionValueType
   * @throws UnexpectedError if no conversion rule applies
   */
  DataVariant
  execute(ExecutionPolicy& exec_policy) override
  {
    if constexpr (std::is_same_v<ReturnType, ExpressionValueType>) {
      return m_function_expression.execute(exec_policy);
    } else if constexpr (std::is_same_v<ReturnType, std::string>) {
      return stringify(std::get<ExpressionValueType>(m_function_expression.execute(exec_policy)));
    } else if constexpr (std::is_same_v<ExpressionValueType, ZeroType>) {
      // the zero value carries no data: the expression is not evaluated
      return ReturnType{ZeroType::zero};
    } else if constexpr (std::is_convertible_v<ExpressionValueType, ReturnType>) {
      // keep the evaluated variant alive while a reference to its content is used
      auto expression_value        = m_function_expression.execute(exec_policy);
      const ExpressionValueType& v = std::get<ExpressionValueType>(expression_value);
      if constexpr (std::is_same_v<ReturnType, uint64_t> and std::is_same_v<ExpressionValueType, int64_t>) {
        // static_cast would silently wrap a negative value around
        if (v < 0) {
          throw std::domain_error("trying to convert negative value (" + stringify(v) + ")");
        }
      }
      return static_cast<ReturnType>(v);
    } else if constexpr (std::is_arithmetic_v<ExpressionValueType> and
                         (is_tiny_vector_v<ReturnType> or is_tiny_matrix_v<ReturnType>)) {
      static_assert(ReturnType::Dimension == 1, "invalid conversion");
      return ReturnType(std::get<ExpressionValueType>(m_function_expression.execute(exec_policy)));
    } else {
      // LCOV_EXCL_START
      throw UnexpectedError("invalid conversion");
      // LCOV_EXCL_STOP
    }
  }

  /**
   * @param function_component_expression expression node to evaluate; only a
   *        reference is stored, so the node must outlive this processor
   */
  FunctionExpressionProcessor(ASTNode& function_component_expression)
    : m_function_expression{function_component_expression}
  {}
};

/**
 * Node processor executing a function call: it evaluates the argument node,
 * hands the argument value(s) to the registered argument converters, then
 * runs the registered function expression processors and returns their
 * value(s).
 */
class FunctionProcessor : public INodeProcessor
{
 private:
  // Node whose evaluation gives the call argument(s)
  ASTNode& m_argument_node;

  // Size and id of the SymbolTable::Context given at construction; used to
  // build the ExecutionPolicy::Context of each call
  const size_t m_context_size;
  const int32_t m_context_id;

  // One converter per argument, in argument order
  std::vector<std::unique_ptr<IFunctionArgumentConverter>> m_argument_converters;
  // One processor per returned expression, in return order
  std::vector<std::unique_ptr<INodeProcessor>> m_function_expression_processors;

 public:
  /**
   * Appends a converter; converter number i receives argument number i.
   *
   * @param argument_converter converter whose ownership is taken
   */
  void
  addArgumentConverter(std::unique_ptr<IFunctionArgumentConverter>&& argument_converter)
  {
    m_argument_converters.emplace_back(std::move(argument_converter));
  }

  /**
   * Appends a processor computing one component of the function result.
   *
   * @param function_processor processor whose ownership is taken
   */
  void
  addFunctionExpressionProcessor(std::unique_ptr<INodeProcessor>&& function_processor)
  {
    m_function_expression_processors.emplace_back(std::move(function_processor));
  }

  /**
   * Executes the function call.
   *
   * With exactly one converter, the whole value of the argument node is given
   * to it. Otherwise the value is read as an AggregateDataVariant and element
   * i is given to converter i. With exactly one expression processor, its
   * value is returned directly; otherwise the values of all processors are
   * gathered, in order, into an AggregateDataVariant.
   *
   * @param exec_policy policy of the caller; a derived policy carrying a
   *        freshly built context is used for the whole call
   * @return the function result
   *
   * @throws ParseError when a std::domain_error is raised while converting an
   *         argument (located at that argument) or while evaluating the
   *         function expressions (located at the argument node)
   */
  DataVariant
  execute(ExecutionPolicy& exec_policy) override
  {
    // Context is built in each execution for thread safety: multiple threads can call a function at once
    ExecutionPolicy::Context context{m_context_id, std::make_shared<ExecutionPolicy::Context::Values>(m_context_size)};

    ExecutionPolicy context_exec_policy{exec_policy, context};

    if (m_argument_converters.size() == 1) {
      try {
        m_argument_converters[0]->convert(context_exec_policy, m_argument_node.execute(context_exec_policy));
      }
      catch (const std::domain_error& e) {
        throw ParseError(e.what(), m_argument_node.begin());
      }
    } else {
      AggregateDataVariant argument_values{
        std::get<AggregateDataVariant>(m_argument_node.execute(context_exec_policy))};

      for (size_t i = 0; i < m_argument_converters.size(); ++i) {
        try {
          // each element is consumed exactly once, so it can be moved from
          m_argument_converters[i]->convert(context_exec_policy, std::move(argument_values[i]));
        }
        catch (const std::domain_error& e) {
          // report the error at the offending argument
          throw ParseError(e.what(), m_argument_node.children[i]->begin());
        }
      }
    }

    try {
      if (m_function_expression_processors.size() == 1) {
        return m_function_expression_processors[0]->execute(context_exec_policy);
      } else {
        std::vector<DataVariant> list_values;
        list_values.reserve(m_function_expression_processors.size());

        for (auto& function_expression_processor : m_function_expression_processors) {
          list_values.emplace_back(function_expression_processor->execute(context_exec_policy));
        }
        return AggregateDataVariant{std::move(list_values)};
      }
    }
    catch (const std::domain_error& e) {
      throw ParseError(e.what(), m_argument_node.begin());
    }
  }

  /**
   * @param argument_node node giving the call arguments; only a reference is
   *        stored, so the node must outlive this processor
   * @param context symbol table context whose size and id are recorded to
   *        build the execution context of each call
   */
  FunctionProcessor(ASTNode& argument_node, SymbolTable::Context context)
    : m_argument_node{argument_node}, m_context_size{context.size()}, m_context_id{context.id()}
  {}
};

#endif   // FUNCTION_PROCESSOR_HPP