#ifndef BINARY_OPERATOR_PROCESSOR_BUILDER_HPP
#define BINARY_OPERATOR_PROCESSOR_BUILDER_HPP

#include <algebra/TinyVector.hpp>
#include <language/PEGGrammar.hpp>
#include <language/node_processor/BinaryExpressionProcessor.hpp>
#include <language/utils/ASTNodeDataTypeTraits.hpp>
#include <language/utils/IBinaryOperatorProcessorBuilder.hpp>
#include <language/utils/ParseError.hpp>

#include <type_traits>

template <typename DataType>
class DataHandler;

/// Builder describing one binary operator overload of the language.
///
/// It reports the language data types of both operands and of the result
/// (via ast_node_data_type_from), and creates the node processor
/// BinaryExpressionProcessor<OperatorT, ValueT, A_DataT, B_DataT> for an AST node.
///
/// @tparam OperatorT  operator tag forwarded to BinaryExpressionProcessor
/// @tparam ValueT     C++ type of the operator result
/// @tparam A_DataT    C++ type of the left hand side operand
/// @tparam B_DataT    C++ type of the right hand side operand
template <typename OperatorT, typename ValueT, typename A_DataT, typename B_DataT>
class BinaryOperatorProcessorBuilder final : public IBinaryOperatorProcessorBuilder
{
 public:
  BinaryOperatorProcessorBuilder() = default;

  /// @return the language data type associated with A_DataT (left operand)
  [[nodiscard]] ASTNodeDataType
  getDataTypeOfA() const override
  {
    return ast_node_data_type_from<A_DataT>;
  }

  /// @return the language data type associated with B_DataT (right operand)
  [[nodiscard]] ASTNodeDataType
  getDataTypeOfB() const override
  {
    return ast_node_data_type_from<B_DataT>;
  }

  /// @return the language data type associated with ValueT (operator result)
  [[nodiscard]] ASTNodeDataType
  getReturnValueType() const override
  {
    return ast_node_data_type_from<ValueT>;
  }

  /// Creates the processor that evaluates the binary expression @p node.
  ///
  /// @param node  the binary-operator AST node. The shared_ptr specializations
  ///              of BinaryExpressionProcessor in this header store a reference
  ///              to it, so presumably node must outlive the processor — TODO
  ///              confirm for the primary template.
  /// @return a heap-allocated BinaryExpressionProcessor<OperatorT, ValueT, A_DataT, B_DataT>
  [[nodiscard]] std::unique_ptr<INodeProcessor>
  getNodeProcessor(ASTNode& node) const override
  {
    return std::make_unique<BinaryExpressionProcessor<OperatorT, ValueT, A_DataT, B_DataT>>(node);
  }
};

/// Specialization for binary expressions whose operands and result are all
/// embedded (shared_ptr-held) data.
///
/// Both operands are taken out of their EmbeddedData / DataHandler wrappers and
/// combined with BinOp<BinaryOpT>::eval. The result is wrapped again in a
/// DataHandler<ValueT>.
template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT>
struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, std::shared_ptr<B_DataT>>
  final : public INodeProcessor
{
 private:
  ASTNode& m_node;

  /// Unwraps both embedded operands and applies the operator.
  ///
  /// @param a  left operand, expected to hold EmbeddedData wrapping a DataHandler<A_DataT>
  /// @param b  right operand, expected to hold EmbeddedData wrapping a DataHandler<B_DataT>
  /// @throws std::bad_variant_access if a or b does not hold EmbeddedData
  /// @throws std::bad_cast if an embedded handler is not of the expected DataHandler type
  PUGS_INLINE DataVariant
  _eval(const DataVariant& a, const DataVariant& b)
  {
    const auto& embedded_a = std::get<EmbeddedData>(a);
    const auto& embedded_b = std::get<EmbeddedData>(b);

    std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr();

    std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr();

    return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_ptr)));
  }

 public:
  /// Executes the left child, then the right child, then applies the operator.
  ///
  /// A NormalError raised during evaluation is rethrown as a ParseError located
  /// at the beginning of this node.
  DataVariant
  execute(ExecutionPolicy& exec_policy) override
  {
    try {
      // Evaluate operands into named locals so that the left child is always
      // executed before the right one. Function-argument evaluation order is
      // unspecified in C++, and child nodes may have side effects.
      const DataVariant lhs = m_node.children[0]->execute(exec_policy);
      const DataVariant rhs = m_node.children[1]->execute(exec_policy);
      return this->_eval(lhs, rhs);
    }
    catch (const NormalError& error) {
      throw ParseError(error.what(), m_node.begin());
    }
  }

  /// @param node  the binary-expression node; a reference to it is stored
  explicit BinaryExpressionProcessor(ASTNode& node) : m_node{node} {}
};

/// Specialization for binary expressions with a plain left operand and an
/// embedded (shared_ptr-held) right operand, producing embedded data.
///
/// The left operand must be arithmetic, a TinyVector or a TinyMatrix; any other
/// A_DataT is rejected at compile time by a static_assert.
template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT>
struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, A_DataT, std::shared_ptr<B_DataT>> final
  : public INodeProcessor
{
 private:
  ASTNode& m_node;

  /// Reads the plain left value, unwraps the embedded right operand and applies
  /// the operator.
  ///
  /// @param a  left operand, expected to hold an A_DataT value
  /// @param b  right operand, expected to hold EmbeddedData wrapping a DataHandler<B_DataT>
  /// @throws std::bad_variant_access if a or b does not hold the expected alternative
  /// @throws std::bad_cast if the embedded handler is not a DataHandler<B_DataT>
  PUGS_INLINE DataVariant
  _eval(const DataVariant& a, const DataVariant& b)
  {
    if constexpr ((std::is_arithmetic_v<A_DataT>) or (is_tiny_vector_v<A_DataT>) or (is_tiny_matrix_v<A_DataT>)) {
      const auto& a_value    = std::get<A_DataT>(a);
      const auto& embedded_b = std::get<EmbeddedData>(b);

      std::shared_ptr b_ptr = dynamic_cast<const DataHandler<B_DataT>&>(embedded_b.get()).data_ptr();

      return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_value, b_ptr)));
    } else {
      // Dependent on A_DataT, so only fires for unsupported instantiations.
      static_assert(std::is_arithmetic_v<A_DataT>, "invalid left hand side type");
    }
  }

 public:
  /// Executes the left child, then the right child, then applies the operator.
  ///
  /// A NormalError raised during evaluation is rethrown as a ParseError located
  /// at the beginning of this node.
  DataVariant
  execute(ExecutionPolicy& exec_policy) override
  {
    try {
      // Evaluate operands into named locals so that the left child is always
      // executed before the right one. Function-argument evaluation order is
      // unspecified in C++, and child nodes may have side effects.
      const DataVariant lhs = m_node.children[0]->execute(exec_policy);
      const DataVariant rhs = m_node.children[1]->execute(exec_policy);
      return this->_eval(lhs, rhs);
    }
    catch (const NormalError& error) {
      throw ParseError(error.what(), m_node.begin());
    }
  }

  /// @param node  the binary-expression node; a reference to it is stored
  explicit BinaryExpressionProcessor(ASTNode& node) : m_node{node} {}
};

/// Specialization for binary expressions with an embedded (shared_ptr-held)
/// left operand and a plain right operand, producing embedded data.
///
/// The right operand must be arithmetic, a TinyMatrix or a TinyVector; any other
/// B_DataT is rejected at compile time by a static_assert.
template <typename BinaryOpT, typename ValueT, typename A_DataT, typename B_DataT>
struct BinaryExpressionProcessor<BinaryOpT, std::shared_ptr<ValueT>, std::shared_ptr<A_DataT>, B_DataT> final
  : public INodeProcessor
{
 private:
  ASTNode& m_node;

  /// Unwraps the embedded left operand, reads the plain right value and applies
  /// the operator.
  ///
  /// @param a  left operand, expected to hold EmbeddedData wrapping a DataHandler<A_DataT>
  /// @param b  right operand, expected to hold a B_DataT value
  /// @throws std::bad_variant_access if a or b does not hold the expected alternative
  /// @throws std::bad_cast if the embedded handler is not a DataHandler<A_DataT>
  PUGS_INLINE DataVariant
  _eval(const DataVariant& a, const DataVariant& b)
  {
    if constexpr ((std::is_arithmetic_v<B_DataT>) or (is_tiny_matrix_v<B_DataT>) or (is_tiny_vector_v<B_DataT>)) {
      const auto& embedded_a = std::get<EmbeddedData>(a);
      const auto& b_value    = std::get<B_DataT>(b);

      std::shared_ptr a_ptr = dynamic_cast<const DataHandler<A_DataT>&>(embedded_a.get()).data_ptr();

      return EmbeddedData(std::make_shared<DataHandler<ValueT>>(BinOp<BinaryOpT>().eval(a_ptr, b_value)));
    } else {
      // Dependent on B_DataT, so only fires for unsupported instantiations.
      static_assert(std::is_arithmetic_v<B_DataT>, "invalid right hand side type");
    }
  }

 public:
  /// Executes the left child, then the right child, then applies the operator.
  ///
  /// A NormalError raised during evaluation is rethrown as a ParseError located
  /// at the beginning of this node.
  DataVariant
  execute(ExecutionPolicy& exec_policy) override
  {
    try {
      // Evaluate operands into named locals so that the left child is always
      // executed before the right one. Function-argument evaluation order is
      // unspecified in C++, and child nodes may have side effects.
      const DataVariant lhs = m_node.children[0]->execute(exec_policy);
      const DataVariant rhs = m_node.children[1]->execute(exec_policy);
      return this->_eval(lhs, rhs);
    }
    catch (const NormalError& error) {
      throw ParseError(error.what(), m_node.begin());
    }
  }

  /// @param node  the binary-expression node; a reference to it is stored
  explicit BinaryExpressionProcessor(ASTNode& node) : m_node{node} {}
};

#endif   // BINARY_OPERATOR_PROCESSOR_BUILDER_HPP