#ifndef AFFECTATION_PROCESSOR_HPP
#define AFFECTATION_PROCESSOR_HPP

#include <node_processor/INodeProcessor.hpp>

#include <SymbolTable.hpp>

// Compound-affectation functor: AffOp<OperatorT>().eval(a, b) applies the
// compound operator designated by the language operator tag OperatorT.
// Only the primary template is declared here; the specializations below
// cover *=, /=, += and -=. Using it with any other tag fails to compile
// because the primary template has no definition.
template <typename Op>
struct AffOp;

/// `*=` affectation: performs `lhs *= rhs`.
template <>
struct AffOp<language::multiplyeq_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE void
  eval(LhsT& lhs, const RhsT& rhs)
  {
    lhs *= rhs;
  }
};

/// `/=` affectation: performs `lhs /= rhs`.
template <>
struct AffOp<language::divideeq_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE void
  eval(LhsT& lhs, const RhsT& rhs)
  {
    lhs /= rhs;
  }
};

/// `+=` affectation: performs `lhs += rhs`.
template <>
struct AffOp<language::pluseq_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE void
  eval(LhsT& lhs, const RhsT& rhs)
  {
    lhs += rhs;
  }
};

/// `-=` affectation: performs `lhs -= rhs`.
template <>
struct AffOp<language::minuseq_op>
{
  template <typename LhsT, typename RhsT>
  PUGS_INLINE void
  eval(LhsT& lhs, const RhsT& rhs)
  {
    lhs -= rhs;
  }
};

/// Type-erased interface of an object that stores a right-hand-side value
/// into a left-hand side bound when the executor was built.
struct IAffectationExecutor
{
  virtual ~IAffectationExecutor() = default;

  /// Stores @p rhs into the bound left-hand side. @p exec_policy is passed on
  /// to any sub-expression the executor still has to evaluate (e.g. an index).
  virtual void affect(ExecutionPolicy& exec_policy, DataVariant&& rhs) = 0;
};

/// Executes an affectation (`=`, `*=`, `/=`, `+=`, `-=`) into a variable whose
/// storage is bound once, at construction time.
///
/// @tparam OperatorT affectation operator tag (language::eq_op, language::pluseq_op, ...)
/// @tparam ValueT    type of the left-hand-side variable
/// @tparam DataT     type held by the right-hand-side DataVariant (ZeroType for a zero)
template <typename OperatorT, typename ValueT, typename DataT>
class AffectationExecutor final : public IAffectationExecutor
{
 private:
  ValueT& m_lhs;

  // Booleans only accept plain affectation (`=`); every other combination
  // is considered valid here.
  static constexpr bool _is_defined{[] {
    if constexpr (std::is_same_v<std::decay_t<ValueT>, bool>) {
      if constexpr (not std::is_same_v<OperatorT, language::eq_op>) {
        return false;
      }
    }
    return true;
  }()};

 public:
  /// @param node AST node whose position is reported if the operands are invalid
  /// @param lhs  storage of the left-hand-side variable; it is referenced, not
  ///             copied, so it must outlive this executor
  /// @throws parse_error if OperatorT is not defined for ValueT
  AffectationExecutor(ASTNode& node, ValueT& lhs) : m_lhs(lhs)
  {
    // LCOV_EXCL_START
    if constexpr (not _is_defined) {
      throw parse_error("unexpected error: invalid operands to affectation expression", std::vector{node.begin()});
    }
    // LCOV_EXCL_STOP
  }

  /// Applies OperatorT to the bound variable using the DataT value held in @p rhs.
  ///
  /// String variables accept `=` and `+=`:
  /// - strings are used as is;
  /// - arithmetic values are converted with std::to_string;
  /// - any other type is formatted with operator<<.
  ///
  /// A ZeroType right-hand side only supports `=`, which stores ValueT{zero}.
  PUGS_INLINE void
  affect(ExecutionPolicy&, DataVariant&& rhs) override
  {
    if constexpr (_is_defined) {
      if constexpr (not std::is_same_v<DataT, ZeroType>) {
        if constexpr (std::is_same_v<ValueT, std::string>) {
          if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
            if constexpr (std::is_same_v<std::string, DataT>) {
              m_lhs = std::get<DataT>(rhs);
            } else if constexpr (std::is_arithmetic_v<DataT>) {
              m_lhs = std::to_string(std::get<DataT>(rhs));
            } else {
              // No std::ends here: it inserts a '\0' that std::ostringstream
              // keeps in str(), which would embed a NUL in the variable.
              std::ostringstream os;
              os << std::get<DataT>(rhs);
              m_lhs = os.str();
            }
          } else {
            static_assert(std::is_same_v<OperatorT, language::pluseq_op>, "unexpected operator type");
            if constexpr (std::is_same_v<std::string, DataT>) {
              m_lhs += std::get<std::string>(rhs);
            } else if constexpr (std::is_arithmetic_v<DataT>) {
              m_lhs += std::to_string(std::get<DataT>(rhs));
            } else {
              std::ostringstream os;
              os << std::get<DataT>(rhs);
              m_lhs += os.str();
            }
          }
        } else {
          if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
            if constexpr (std::is_same_v<ValueT, DataT>) {
              m_lhs = std::get<DataT>(rhs);
            } else {
              m_lhs = static_cast<ValueT>(std::get<DataT>(rhs));
            }
          } else {
            AffOp<OperatorT>().eval(m_lhs, std::get<DataT>(rhs));
          }
        }
      } else if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
        // Compile-time branch (like all the others) so that the static_assert
        // below is only instantiated for unsupported operators.
        m_lhs = ValueT{zero};
      } else {
        static_assert(std::is_same_v<OperatorT, language::eq_op>, "unexpected operator type");
      }
    }
  }
};

template <typename OperatorT, typename ArrayT, typename ValueT, typename DataT>
class ComponentAffectationExecutor final : public IAffectationExecutor
{
 private:
  ArrayT& m_lhs_array;
  ASTNode& m_index_expression;

  static inline const bool _is_defined{[] {
    if constexpr (not std::is_same_v<typename ArrayT::data_type, ValueT>) {
      return false;
    } else if constexpr (std::is_same_v<std::decay_t<ValueT>, bool>) {
      if constexpr (not std::is_same_v<OperatorT, language::eq_op>) {
        return false;
      }
    }
    return true;
  }()};

 public:
  ComponentAffectationExecutor(ASTNode& node, ArrayT& lhs_array, ASTNode& index_expression)
    : m_lhs_array{lhs_array}, m_index_expression{index_expression}
  {
    // LCOV_EXCL_START
    if constexpr (not _is_defined) {
      throw parse_error("unexpected error: invalid operands to affectation expression", std::vector{node.begin()});
    }
    // LCOV_EXCL_STOP
  }

  PUGS_INLINE void
  affect(ExecutionPolicy& exec_policy, DataVariant&& rhs)
  {
    if constexpr (_is_defined) {
      const int64_t index_value = [&](DataVariant&& value_variant) -> int64_t {
        int64_t index_value = 0;
        std::visit(
          [&](auto&& value) {
            using IndexValueT = std::decay_t<decltype(value)>;
            if constexpr (std::is_integral_v<IndexValueT>) {
              index_value = value;
            } else {
              throw parse_error("unexpected error: invalid index type", std::vector{m_index_expression.begin()});
            }
          },
          value_variant);
        return index_value;
      }(m_index_expression.execute(exec_policy));

      if constexpr (std::is_same_v<ValueT, std::string>) {
        if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
          if constexpr (std::is_same_v<std::string, DataT>) {
            m_lhs_array[index_value] = std::get<DataT>(rhs);
          } else {
            m_lhs_array[index_value] = std::to_string(std::get<DataT>(rhs));
          }
        } else {
          static_assert(std::is_same_v<OperatorT, language::pluseq_op>, "unexpected operator type");
          if constexpr (std::is_same_v<std::string, DataT>) {
            m_lhs_array[index_value] += std::get<std::string>(rhs);
          } else {
            m_lhs_array[index_value] += std::to_string(std::get<DataT>(rhs));
          }
        }
      } else {
        if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
          if constexpr (std::is_same_v<ValueT, DataT>) {
            m_lhs_array[index_value] = std::get<DataT>(rhs);
          } else {
            m_lhs_array[index_value] = static_cast<ValueT>(std::get<DataT>(rhs));
          }
        } else {
          AffOp<OperatorT>().eval(m_lhs_array[index_value], std::get<DataT>(rhs));
        }
      }
    }
  }
};

template <typename OperatorT, typename ValueT, typename DataT>
class AffectationProcessor final : public INodeProcessor
{
 private:
  ASTNode& m_node;

  std::unique_ptr<IAffectationExecutor> m_affectation_executor;

 public:
  DataVariant
  execute(ExecutionPolicy& exec_policy)
  {
    m_affectation_executor->affect(exec_policy, m_node.children[1]->execute(exec_policy));

    return {};
  }

  AffectationProcessor(ASTNode& node) : m_node{node}
  {
    if (node.children[0]->is_type<language::name>()) {
      const std::string& symbol = m_node.children[0]->string();
      auto [i_symbol, found]    = m_node.m_symbol_table->find(symbol, m_node.children[0]->begin());
      Assert(found);
      DataVariant& value = i_symbol->attributes().value();

      if (not std::holds_alternative<ValueT>(value)) {
        value = ValueT{};
      }

      using AffectationExecutorT = AffectationExecutor<OperatorT, ValueT, DataT>;
      m_affectation_executor     = std::make_unique<AffectationExecutorT>(m_node, std::get<ValueT>(value));
    } else if (node.children[0]->is_type<language::subscript_expression>()) {
      auto& array_subscript_expression = *node.children[0];

      auto& array_expression = *array_subscript_expression.children[0];
      Assert(array_expression.is_type<language::name>());

      const std::string& symbol = array_expression.string();

      auto [i_symbol, found] = m_node.m_symbol_table->find(symbol, array_subscript_expression.begin());
      Assert(found);
      DataVariant& value = i_symbol->attributes().value();

      if (array_expression.m_data_type != ASTNodeDataType::vector_t) {
        throw parse_error("unexpected error: invalid lhs (expecting R^d)",
                          std::vector{array_subscript_expression.begin()});
      }

      auto& index_expression = *array_subscript_expression.children[1];

      switch (array_expression.m_data_type.dimension()) {
      case 1: {
        using ArrayTypeT = TinyVector<1>;
        if (not std::holds_alternative<ArrayTypeT>(value)) {
          value = ArrayTypeT{};
        }
        using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>;
        m_affectation_executor =
          std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression);
        break;
      }
      case 2: {
        using ArrayTypeT = TinyVector<2>;
        if (not std::holds_alternative<ArrayTypeT>(value)) {
          value = ArrayTypeT{};
        }
        using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>;
        m_affectation_executor =
          std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression);
        break;
      }
      case 3: {
        using ArrayTypeT = TinyVector<3>;
        if (not std::holds_alternative<ArrayTypeT>(value)) {
          value = ArrayTypeT{};
        }
        using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>;
        m_affectation_executor =
          std::make_unique<AffectationExecutorT>(node, std::get<ArrayTypeT>(value), index_expression);
        break;
      }
      default: {
        throw parse_error("unexpected error: invalid vector dimension",
                          std::vector{array_subscript_expression.begin()});
      }
      }

    } else {
      throw parse_error("unexpected error: invalid lhs", std::vector{node.children[0]->begin()});
    }
  }
};

/// Node processor that affects a list of scalar values (children[1], evaluated
/// to an AggregateDataVariant) component-wise to a vector variable (children[0]).
/// Only plain affectation (`=`) is allowed.
///
/// @tparam OperatorT affectation operator tag; must be language::eq_op
/// @tparam ValueT    vector type of the variable (provides dimension() and operator[])
template <typename OperatorT, typename ValueT>
class AffectationFromListProcessor final : public INodeProcessor
{
 private:
  ASTNode& m_node;

  DataVariant* m_lhs;

 public:
  /// Evaluates the list and stores it into the variable.
  /// Each list entry must be a bool, uint64_t, int64_t or double.
  /// @return an empty DataVariant
  /// @throws parse_error if a list entry has another type
  DataVariant
  execute(ExecutionPolicy& exec_policy)
  {
    AggregateDataVariant children_values = std::get<AggregateDataVariant>(m_node.children[1]->execute(exec_policy));

    static_assert(std::is_same_v<OperatorT, language::eq_op>, "forbidden affection operator for list to vectors");

    ValueT v;
    // Component i reads children_values[i]: the list must provide exactly one
    // value per component, otherwise the loop below would read past its end.
    Assert(children_values.size() == v.dimension());

    for (size_t i = 0; i < v.dimension(); ++i) {
      std::visit(
        [&](auto&& child_value) {
          using T = std::decay_t<decltype(child_value)>;
          if constexpr (std::is_same_v<T, bool> or std::is_same_v<T, uint64_t> or std::is_same_v<T, int64_t> or
                        std::is_same_v<T, double>) {
            v[i] = child_value;
          } else {
            // Same parse_error form (vector of positions) as every other throw in this file.
            throw parse_error("unexpected error: unexpected right hand side type in affectation",
                              std::vector{m_node.begin()});
          }
        },
        children_values[i]);
    }

    *m_lhs = v;
    return {};
  }

  /// Looks up the variable named by children[0]; its DataVariant is
  /// overwritten by execute().
  AffectationFromListProcessor(ASTNode& node) : m_node{node}
  {
    const std::string& symbol = m_node.children[0]->string();
    auto [i_symbol, found]    = m_node.m_symbol_table->find(symbol, m_node.children[0]->begin());
    Assert(found);

    m_lhs = &i_symbol->attributes().value();
  }
};

template <typename ValueT>
class AffectationFromZeroProcessor final : public INodeProcessor
{
 private:
  ASTNode& m_node;

  DataVariant* m_lhs;

 public:
  DataVariant
  execute(ExecutionPolicy&)
  {
    *m_lhs = ValueT{zero};
    return {};
  }

  AffectationFromZeroProcessor(ASTNode& node) : m_node{node}
  {
    const std::string& symbol = m_node.children[0]->string();
    auto [i_symbol, found]    = m_node.m_symbol_table->find(symbol, m_node.children[0]->begin());
    Assert(found);

    m_lhs = &i_symbol->attributes().value();
  }
};

template <typename OperatorT>
class ListAffectationProcessor final : public INodeProcessor
{
 private:
  ASTNode& m_node;

  std::vector<std::unique_ptr<IAffectationExecutor>> m_affectation_executor_list;

 public:
  template <typename ValueT, typename DataT>
  void
  add(ASTNode& lhs_node)
  {
    using AffectationExecutorT = AffectationExecutor<OperatorT, ValueT, DataT>;

    if (lhs_node.is_type<language::name>()) {
      const std::string& symbol = lhs_node.string();
      auto [i_symbol, found]    = m_node.m_symbol_table->find(symbol, m_node.children[0]->end());
      Assert(found);
      DataVariant& value = i_symbol->attributes().value();

      if (not std::holds_alternative<ValueT>(value)) {
        value = ValueT{};
      }

      m_affectation_executor_list.emplace_back(std::make_unique<AffectationExecutorT>(m_node, std::get<ValueT>(value)));
    } else if (lhs_node.is_type<language::subscript_expression>()) {
      auto& array_subscript_expression = lhs_node;

      auto& array_expression = *array_subscript_expression.children[0];
      Assert(array_expression.is_type<language::name>());

      const std::string& symbol = array_expression.string();

      auto [i_symbol, found] = m_node.m_symbol_table->find(symbol, array_subscript_expression.begin());
      Assert(found);
      DataVariant& value = i_symbol->attributes().value();

      if (array_expression.m_data_type != ASTNodeDataType::vector_t) {
        throw parse_error("unexpected error: invalid lhs (expecting R^d)",
                          std::vector{array_subscript_expression.begin()});
      }

      auto& index_expression = *array_subscript_expression.children[1];

      switch (array_expression.m_data_type.dimension()) {
      case 1: {
        using ArrayTypeT = TinyVector<1>;
        if (not std::holds_alternative<ArrayTypeT>(value)) {
          value = ArrayTypeT{};
        }
        using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>;
        m_affectation_executor_list.emplace_back(
          std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression));
        break;
      }
      case 2: {
        using ArrayTypeT = TinyVector<2>;
        if (not std::holds_alternative<ArrayTypeT>(value)) {
          value = ArrayTypeT{};
        }
        using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>;
        m_affectation_executor_list.emplace_back(
          std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression));
        break;
      }
      case 3: {
        using ArrayTypeT = TinyVector<3>;
        if (not std::holds_alternative<ArrayTypeT>(value)) {
          value = ArrayTypeT{};
        }
        using AffectationExecutorT = ComponentAffectationExecutor<OperatorT, ArrayTypeT, ValueT, DataT>;
        m_affectation_executor_list.emplace_back(
          std::make_unique<AffectationExecutorT>(lhs_node, std::get<ArrayTypeT>(value), index_expression));
        break;
      }
      default: {
        throw parse_error("unexpected error: invalid vector dimension",
                          std::vector{array_subscript_expression.begin()});
      }
      }
    } else {
      throw parse_error("unexpected error: invalid left hand side", std::vector{lhs_node.begin()});
    }
  }

  DataVariant
  execute(ExecutionPolicy& exec_policy)
  {
    AggregateDataVariant children_values = std::get<AggregateDataVariant>(m_node.children[1]->execute(exec_policy));
    Assert(m_affectation_executor_list.size() == children_values.size());

    for (size_t i = 0; i < m_affectation_executor_list.size(); ++i) {
      m_affectation_executor_list[i]->affect(exec_policy, std::move(children_values[i]));
    }

    return {};
  }

  ListAffectationProcessor(ASTNode& node) : m_node{node} {}
};

#endif   // AFFECTATION_PROCESSOR_HPP