#ifndef FUNCTION_PROCESSOR_HPP
#define FUNCTION_PROCESSOR_HPP

#include <language/PEGGrammar.hpp>
#include <language/node_processor/ASTNodeExpressionListProcessor.hpp>
#include <language/node_processor/FunctionArgumentConverter.hpp>
#include <language/node_processor/INodeProcessor.hpp>
#include <language/utils/FunctionTable.hpp>
#include <language/utils/SymbolTable.hpp>

/// Node processor evaluating one component expression of a user-defined
/// function body and converting its value to the declared return type.
///
/// @tparam ReturnType           type the component value must be converted to
/// @tparam ExpressionValueType  type of the value produced by the component expression
template <typename ReturnType, typename ExpressionValueType>
class FunctionExpressionProcessor final : public INodeProcessor
{
 private:
  // Component expression of the function body (referenced, not owned).
  ASTNode& m_function_expression;

 public:
  /// Evaluates the component expression and converts its value to ReturnType.
  ///
  /// @param exec_policy  execution policy forwarded to the component expression
  /// @return the converted value wrapped in a DataVariant
  /// @throws UnexpectedError if the (ExpressionValueType, ReturnType) pair has no
  ///         supported conversion, or if a component of an aggregate value is not
  ///         convertible to double
  DataVariant
  execute(ExecutionPolicy& exec_policy) override
  {
    if constexpr (std::is_same_v<ReturnType, ExpressionValueType>) {
      // Types already match: forward the expression value unchanged.
      return m_function_expression.execute(exec_policy);
    } else if constexpr (std::is_same_v<AggregateDataVariant, ExpressionValueType>) {
      // A list of values is used to fill a TinyVector or a TinyMatrix component-wise.
      static_assert(is_tiny_vector_v<ReturnType> or is_tiny_matrix_v<ReturnType>, "unexpected return type");
      ReturnType return_value{};
      auto value = std::get<ExpressionValueType>(m_function_expression.execute(exec_policy));
      if constexpr (is_tiny_vector_v<ReturnType>) {
        for (size_t i = 0; i < ReturnType::Dimension; ++i) {
          std::visit(
            [&](auto&& vi) {
              using Vi_T = std::decay_t<decltype(vi)>;
              if constexpr (std::is_convertible_v<Vi_T, double>) {
                return_value[i] = vi;
              } else {
                // Report the unexpected component type instead of silently
                // leaving a value-initialized entry in the result.
                throw UnexpectedError("invalid conversion");
              }
            },
            value[i]);
        }
      } else {
        static_assert(is_tiny_matrix_v<ReturnType>);

        // The aggregate components are read in row-major order: l is the flat index.
        for (size_t i = 0, l = 0; i < return_value.numberOfRows(); ++i) {
          for (size_t j = 0; j < return_value.numberOfColumns(); ++j, ++l) {
            std::visit(
              [&](auto&& Aij) {
                using Vi_T = std::decay_t<decltype(Aij)>;
                if constexpr (std::is_convertible_v<Vi_T, double>) {
                  return_value(i, j) = Aij;
                } else {
                  // Report the unexpected component type instead of silently
                  // leaving a value-initialized entry in the result.
                  throw UnexpectedError("invalid conversion");
                }
              },
              value[l]);
          }
        }
      }
      return return_value;
    } else if constexpr (std::is_same_v<ReturnType, std::string>) {
      // Textual conversion relies on std::to_string (arithmetic values only).
      return std::to_string(std::get<ExpressionValueType>(m_function_expression.execute(exec_policy)));
    } else if constexpr (std::is_same_v<ExpressionValueType, ZeroType>) {
      // The expression is the zero literal: build the zero value of ReturnType
      // without evaluating the expression.
      return ReturnType{ZeroType::zero};
    } else if constexpr (std::is_convertible_v<ExpressionValueType, ReturnType>) {
      return static_cast<ReturnType>(std::get<ExpressionValueType>(m_function_expression.execute(exec_policy)));
    } else if constexpr (std::is_arithmetic_v<ExpressionValueType> and
                         (is_tiny_vector_v<ReturnType> or is_tiny_matrix_v<ReturnType>)) {
      // A scalar may only initialize a TinyVector/TinyMatrix with Dimension == 1.
      static_assert(ReturnType::Dimension == 1, "invalid conversion");
      return ReturnType(std::get<ExpressionValueType>(m_function_expression.execute(exec_policy)));
    } else {
      throw UnexpectedError("invalid conversion");
    }
  }

  /// @param function_component_expression  component expression to evaluate (must outlive this processor)
  FunctionExpressionProcessor(ASTNode& function_component_expression)
    : m_function_expression{function_component_expression}
  {}
};

/// Node processor executing a call to a user-defined function.
///
/// Each execution builds a fresh call context. It then evaluates the argument
/// node and hands each argument value to its converter, and finally evaluates
/// the function body component(s) under that context.
class FunctionProcessor : public INodeProcessor
{
 private:
  // Node producing the call argument value(s) (referenced, not owned).
  ASTNode& m_argument_node;

  // Size and identifier of the per-call context created in execute().
  const size_t m_context_size;
  const int32_t m_context_id;

  // One converter per function argument, applied in declaration order.
  // NOTE(review): presumably each converter stores its value in the call context — confirm in FunctionArgumentConverter.
  std::vector<std::unique_ptr<IFunctionArgumentConverter>> m_argument_converters;
  // One processor per component of the function image.
  std::vector<std::unique_ptr<INodeProcessor>> m_function_expression_processors;

 public:
  /// Appends the converter for the next function argument.
  void
  addArgumentConverter(std::unique_ptr<IFunctionArgumentConverter>&& argument_converter)
  {
    m_argument_converters.emplace_back(std::move(argument_converter));
  }

  /// Appends the processor for the next component of the function image.
  void
  addFunctionExpressionProcessor(std::unique_ptr<INodeProcessor>&& function_processor)
  {
    m_function_expression_processors.emplace_back(std::move(function_processor));
  }

  /// Evaluates the function call.
  ///
  /// @param exec_policy  caller's execution policy; a child policy carrying the new call context is derived from it
  /// @return the single component value, or an AggregateDataVariant holding all
  ///         component values when the function image has several components
  DataVariant
  execute(ExecutionPolicy& exec_policy) override
  {
    // The context is built at each execution for thread safety: multiple threads can call a function at once.
    ExecutionPolicy::Context context{m_context_id, std::make_shared<ExecutionPolicy::Context::Values>(m_context_size)};

    ExecutionPolicy context_exec_policy{exec_policy, context};

    if (m_argument_converters.size() == 1) {
      // Single argument: the argument node yields the value directly.
      m_argument_converters[0]->convert(context_exec_policy, m_argument_node.execute(context_exec_policy));
    } else {
      // Several arguments: the argument node yields a list whose i-th entry
      // is handed to the i-th converter.
      AggregateDataVariant argument_values{
        std::get<AggregateDataVariant>(m_argument_node.execute(context_exec_policy))};

      for (size_t i = 0; i < m_argument_converters.size(); ++i) {
        m_argument_converters[i]->convert(context_exec_policy, std::move(argument_values[i]));
      }
    }

    if (m_function_expression_processors.size() == 1) {
      return m_function_expression_processors[0]->execute(context_exec_policy);
    }

    // Several image components: gather their values into an aggregate.
    std::vector<DataVariant> list_values;
    list_values.reserve(m_function_expression_processors.size());

    for (auto& function_expression_processor : m_function_expression_processors) {
      list_values.emplace_back(function_expression_processor->execute(context_exec_policy));
    }
    return AggregateDataVariant{std::move(list_values)};
  }

  /// @param argument_node  node evaluating the call arguments (must outlive this processor)
  /// @param context        symbol-table context whose size and id are used to build each call context
  FunctionProcessor(ASTNode& argument_node, SymbolTable::Context context)
    : m_argument_node{argument_node}, m_context_size{context.size()}, m_context_id{context.id()}
  {}
};

#endif   // FUNCTION_PROCESSOR_HPP