#ifndef AFFECTATION_PROCESSOR_HPP
#define AFFECTATION_PROCESSOR_HPP

#include <language/PEGGrammar.hpp>
#include <language/node_processor/INodeProcessor.hpp>
#include <language/utils/ParseError.hpp>
#include <language/utils/SymbolTable.hpp>
#include <utils/Exceptions.hpp>
#include <utils/PugsTraits.hpp>

#include <cstdint>
#include <exception>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <variant>
#include <vector>

// Primary template for compound affectation operators (*=, /=, +=, -=).
// Only the explicit specializations below are defined: using AffOp with any
// other operator tag is a compile-time error (incomplete type).
template <typename Op>
struct AffOp;

/// `*=` affectation: performs a *= b.
///
/// When the left-hand side is a uint64_t and the factor an int64_t, a negative
/// factor is rejected with std::domain_error rather than wrapping around.
template <>
struct AffOp<language::multiplyeq_op>
{
  template <typename A, typename B>
  PUGS_INLINE void
  eval(A& a, const B& b)
  {
    constexpr bool unsigned_lhs_with_signed_rhs = std::is_same_v<A, uint64_t> and std::is_same_v<B, int64_t>;
    if constexpr (unsigned_lhs_with_signed_rhs) {
      if (b < 0) {
        const std::string message = "trying to affect negative value (" + std::to_string(b) + ")";
        throw std::domain_error(message);
      }
    }

    a *= b;
  }
};

/// `/=` affectation: performs a /= b.
///
/// Throws std::domain_error (the affectation processors in this file catch it
/// and rethrow a ParseError located at the rhs) when:
///  - the left-hand side is a uint64_t and the divisor a negative int64_t;
///  - the left-hand side is an integer and the divisor is zero. Integer
///    division by zero is undefined behavior in C++. So is converting an
///    infinite floating quotient back to an integer (e.g. int /= 0.).
template <>
struct AffOp<language::divideeq_op>
{
  template <typename A, typename B>
  PUGS_INLINE void
  eval(A& a, const B& b)
  {
    if constexpr (std::is_same_v<uint64_t, A> and std::is_same_v<int64_t, B>) {
      if (b < 0) {
        throw std::domain_error("trying to affect negative value (" + std::to_string(b) + ")");
      }
    }
    if constexpr (std::is_integral_v<A>) {
      if (b == 0) {
        throw std::domain_error("division by zero");
      }
    }
    a /= b;
  }
};

/// `+=` affectation: performs a += b.
///
/// For a uint64_t lhs and an int64_t rhs, throws std::domain_error when the
/// mathematical result a + b would be negative, i.e. when b is negative and
/// |b| exceeds a. The comparison is done on magnitudes. Casting a + b to
/// int64_t would wrongly reject valid results >= 2^63.
template <>
struct AffOp<language::pluseq_op>
{
  template <typename A, typename B>
  PUGS_INLINE void
  eval(A& a, const B& b)
  {
    if constexpr (std::is_same_v<uint64_t, A> and std::is_same_v<int64_t, B>) {
      // 0 - uint64_t(b) is |b| for any negative b (including INT64_MIN),
      // computed with well-defined modular unsigned arithmetic
      if ((b < 0) and (a < uint64_t{0} - static_cast<uint64_t>(b))) {
        throw std::domain_error("trying to affect negative value (lhs: " + std::to_string(a) +
                                " rhs: " + std::to_string(b) + ")");
      }
    }
    a += b;
  }
};

/// `-=` affectation: performs a -= b.
///
/// For a uint64_t lhs and an int64_t rhs, throws std::domain_error when the
/// mathematical result a - b would be negative, i.e. when b is positive and
/// larger than a. The comparison is done on magnitudes. Casting a - b to
/// int64_t would wrongly reject valid results >= 2^63.
template <>
struct AffOp<language::minuseq_op>
{
  template <typename A, typename B>
  PUGS_INLINE void
  eval(A& a, const B& b)
  {
    if constexpr (std::is_same_v<uint64_t, A> and std::is_same_v<int64_t, B>) {
      // a negative b increases a, so only a positive b larger than a is invalid
      if ((b > 0) and (static_cast<uint64_t>(b) > a)) {
        throw std::domain_error("trying to affect negative value (lhs: " + std::to_string(a) +
                                " rhs: " + std::to_string(b) + ")");
      }
    }
    a -= b;
  }
};

/// Type-erased interface for objects that write an already evaluated
/// right-hand side value into a left-hand side. Executors are neither
/// copyable nor movable.
struct IAffectationExecutor
{
  IAffectationExecutor() = default;

  IAffectationExecutor(IAffectationExecutor&&)      = delete;
  IAffectationExecutor(const IAffectationExecutor&) = delete;

  /// Applies the affectation using rhs (taken by rvalue reference).
  virtual void affect(ExecutionPolicy& exec_policy, DataVariant&& rhs) = 0;

  virtual ~IAffectationExecutor() = default;
};

/// Writes a right-hand side value (held in a DataVariant as DataT) into a
/// left-hand side DataVariant storing a ValueT, using the affectation operator
/// OperatorT (language::eq_op for `=`, otherwise a compound operator handled by
/// AffOp<OperatorT>).
///
/// Handled cases:
///  - ValueT == std::string: `=` replaces the string with a textual form of the
///    rhs; other operators append that textual form;
///  - `=` with an rhs convertible to ValueT (an int64_t -> uint64_t affectation
///    throws std::domain_error on negative values);
///  - `=` from an AggregateDataVariant list into a TinyVector / TinyMatrix;
///  - `=` from a scalar into TinyVector<1> / TinyMatrix<1>;
///  - `=` from ZeroType, which sets the lhs to ValueT{zero};
///  - other operators on non-string values, delegated to AffOp<OperatorT>.
///
/// NOTE(review): with `=` a bool rhs is written into a string via
/// std::to_string ("1"/"0"), whereas `+=` appends "true"/"false". This is kept
/// unchanged since scripts may rely on it.
template <typename OperatorT, typename ValueT, typename DataT>
class AffectationExecutor final : public IAffectationExecutor
{
 private:
  // value storage of the lhs symbol
  DataVariant& m_lhs;
  ASTNode& m_node;

  // Any operator other than `=` is undefined for a boolean lhs.
  static inline const bool m_is_defined{[] {
    if constexpr (std::is_same_v<std::decay_t<ValueT>, bool>) {
      if constexpr (not std::is_same_v<OperatorT, language::eq_op>) {
        return false;
      }
    }
    return true;
  }()};

 public:
  AffectationExecutor(ASTNode& node, DataVariant& lhs) : m_lhs(lhs), m_node{node}
  {
    // LCOV_EXCL_START
    if constexpr (not m_is_defined) {
      throw ParseError("unexpected error: invalid operands to affectation expression", std::vector{node.begin()});
    }
    // LCOV_EXCL_STOP
  }

  PUGS_INLINE void
  affect(ExecutionPolicy&, DataVariant&& rhs) override
  {
    if constexpr (m_is_defined) {
      if constexpr (not std::is_same_v<DataT, ZeroType>) {
        if constexpr (std::is_same_v<ValueT, std::string>) {
          if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
            if constexpr (std::is_same_v<std::string, DataT>) {
              // rhs is received by rvalue reference: move the string instead of copying it
              m_lhs = std::move(std::get<DataT>(rhs));
            } else if constexpr (std::is_arithmetic_v<DataT>) {
              m_lhs = std::to_string(std::get<DataT>(rhs));
            } else {
              std::ostringstream os;
              os << std::get<DataT>(rhs);
              m_lhs = os.str();
            }
          } else {
            if constexpr (std::is_same_v<std::string, DataT>) {
              std::get<std::string>(m_lhs) += std::get<std::string>(rhs);
            } else if constexpr (std::is_arithmetic_v<DataT> and not std::is_same_v<bool, DataT>) {
              std::get<std::string>(m_lhs) += std::to_string(std::get<DataT>(rhs));
            } else {
              std::ostringstream os;
              os << std::boolalpha << std::get<DataT>(rhs);
              std::get<std::string>(m_lhs) += os.str();
            }
          }
        } else {
          if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
            if constexpr (std::is_convertible_v<DataT, ValueT>) {
              const DataT& value = std::get<DataT>(rhs);
              if constexpr (std::is_same_v<uint64_t, ValueT> and std::is_same_v<int64_t, DataT>) {
                if (value < 0) {
                  throw std::domain_error("trying to affect negative value (" + std::to_string(value) + ")");
                }
              }
              m_lhs = static_cast<ValueT>(value);
            } else if constexpr (std::is_same_v<DataT, AggregateDataVariant>) {
              const AggregateDataVariant& v = std::get<AggregateDataVariant>(rhs);
              if constexpr (is_tiny_vector_v<ValueT>) {
                ValueT value;
                for (size_t i = 0; i < ValueT::Dimension; ++i) {
                  std::visit(
                    [&](auto&& vi) {
                      using Vi_T = std::decay_t<decltype(vi)>;
                      if constexpr (std::is_convertible_v<Vi_T, double>) {
                        value[i] = vi;
                      } else {
                        // LCOV_EXCL_START
                        throw UnexpectedError("unexpected rhs type in affectation");
                        // LCOV_EXCL_STOP
                      }
                    },
                    v[i]);
                }
                m_lhs = value;
              } else if constexpr (is_tiny_matrix_v<ValueT>) {
                ValueT value;
                // the list is stored row by row: l is the flat index of (i, j)
                for (size_t i = 0, l = 0; i < ValueT::NumberOfRows; ++i) {
                  for (size_t j = 0; j < ValueT::NumberOfColumns; ++j, ++l) {
                    std::visit(
                      [&](auto&& Aij) {
                        using Aij_T = std::decay_t<decltype(Aij)>;
                        if constexpr (std::is_convertible_v<Aij_T, double>) {
                          value(i, j) = Aij;
                        } else {
                          // LCOV_EXCL_START
                          throw UnexpectedError("unexpected rhs type in affectation");
                          // LCOV_EXCL_STOP
                        }
                      },
                      v[l]);
                  }
                }
                m_lhs = value;
              } else {
                static_assert(is_tiny_matrix_v<ValueT> or is_tiny_vector_v<ValueT>, "invalid rhs type");
              }
            } else if constexpr (std::is_same_v<TinyVector<1>, ValueT>) {
              std::visit(
                [&](auto&& v) {
                  using Vi_T = std::decay_t<decltype(v)>;
                  if constexpr (std::is_convertible_v<Vi_T, double>) {
                    m_lhs = TinyVector<1>(v);
                  } else {
                    // LCOV_EXCL_START
                    throw UnexpectedError("unexpected rhs type in affectation");
                    // LCOV_EXCL_STOP
                  }
                },
                rhs);
            } else if constexpr (std::is_same_v<TinyMatrix<1>, ValueT>) {
              std::visit(
                [&](auto&& v) {
                  using Vi_T = std::decay_t<decltype(v)>;
                  if constexpr (std::is_convertible_v<Vi_T, double>) {
                    m_lhs = TinyMatrix<1>(v);
                  } else {
                    // LCOV_EXCL_START
                    throw UnexpectedError("unexpected rhs type in affectation");
                    // LCOV_EXCL_STOP
                  }
                },
                rhs);
            } else {
              throw UnexpectedError("invalid value type");
            }
          } else {
            AffOp<OperatorT>().eval(std::get<ValueT>(m_lhs), std::get<DataT>(rhs));
          }
        }
      } else if constexpr (std::is_same_v<OperatorT, language::eq_op>) {
        // a literal 0 affected to a (vector/matrix) value
        m_lhs = ValueT{zero};
      } else {
        // discarded for `=`: only reached (and rejected) for other operators
        static_assert(std::is_same_v<OperatorT, language::eq_op>, "unexpected operator type");
      }
    }
  }
};

/// Node processor for a single affectation `lhs OperatorT rhs` whose lhs is a
/// named symbol. The executor bound to the symbol's value is created once at
/// construction. execute() evaluates the rhs node and hands the result over.
/// std::domain_error raised during the affectation becomes a ParseError
/// located at the rhs.
template <typename OperatorT, typename ValueT, typename DataT>
class AffectationProcessor final : public INodeProcessor
{
 private:
  ASTNode& m_rhs_node;

  std::unique_ptr<IAffectationExecutor> m_affectation_executor;

  std::unique_ptr<IAffectationExecutor>
  _buildAffectationExecutor(ASTNode& lhs_node)
  {
    if (not lhs_node.is_type<language::name>()) {
      // LCOV_EXCL_START
      throw ParseError("unexpected error: invalid lhs", std::vector{lhs_node.begin()});
      // LCOV_EXCL_STOP
    }

    const std::string& symbol_name = lhs_node.string();
    auto [i_symbol, found]         = lhs_node.m_symbol_table->find(symbol_name, lhs_node.begin());
    Assert(found);

    DataVariant& lhs_value = i_symbol->attributes().value();
    return std::make_unique<AffectationExecutor<OperatorT, ValueT, DataT>>(lhs_node, lhs_value);
  }

 public:
  DataVariant
  execute(ExecutionPolicy& exec_policy)
  {
    try {
      DataVariant rhs_value = m_rhs_node.execute(exec_policy);
      m_affectation_executor->affect(exec_policy, std::move(rhs_value));
    }
    catch (const std::domain_error& e) {
      throw ParseError(e.what(), m_rhs_node.begin());
    }
    return {};
  }

  AffectationProcessor(ASTNode& lhs_node, ASTNode& rhs_node)
    : m_rhs_node{rhs_node}, m_affectation_executor{this->_buildAffectationExecutor(lhs_node)}
  {}
};

class AffectationToDataVariantProcessorBase : public INodeProcessor
{
 protected:
  DataVariant* m_lhs;

 public:
  AffectationToDataVariantProcessorBase(ASTNode& lhs_node)
  {
    const std::string& symbol = lhs_node.string();
    auto [i_symbol, found]    = lhs_node.m_symbol_table->find(symbol, lhs_node.begin());
    Assert(found);

    m_lhs = &i_symbol->attributes().value();
  }

  virtual ~AffectationToDataVariantProcessorBase() = default;
};

/// Affects a list expression `(x_0, ..., x_{N-1})` to a TinyVector lhs of type
/// ValueT. Only `=` is allowed (enforced at compile time). Each list entry must
/// be a bool, uint64_t, int64_t or double and is converted to the
/// corresponding component.
template <typename OperatorT, typename ValueT>
class AffectationToTinyVectorFromListProcessor final : public AffectationToDataVariantProcessorBase
{
 private:
  ASTNode& m_rhs_node;

 public:
  DataVariant
  execute(ExecutionPolicy& exec_policy)
  {
    static_assert(std::is_same_v<OperatorT, language::eq_op>, "forbidden affectation operator for list to vectors");

    AggregateDataVariant children_values = std::get<AggregateDataVariant>(m_rhs_node.execute(exec_policy));

    ValueT v;
    // exactly one list entry per component is expected (children_values[i] is read below)
    Assert(children_values.size() == v.dimension());

    for (size_t i = 0; i < v.dimension(); ++i) {
      std::visit(
        [&](auto&& child_value) {
          using T = std::decay_t<decltype(child_value)>;
          if constexpr (std::is_same_v<T, bool> or std::is_same_v<T, uint64_t> or std::is_same_v<T, int64_t> or
                        std::is_same_v<T, double>) {
            v[i] = child_value;
          } else {
            // LCOV_EXCL_START
            throw ParseError("unexpected error: unexpected right hand side type in affectation", m_rhs_node.begin());
            // LCOV_EXCL_STOP
          }
        },
        children_values[i]);
    }

    *m_lhs = v;
    return {};
  }

  AffectationToTinyVectorFromListProcessor(ASTNode& lhs_node, ASTNode& rhs_node)
    : AffectationToDataVariantProcessorBase(lhs_node), m_rhs_node{rhs_node}
  {}
};

/// Affects a list expression to a TinyMatrix lhs of type ValueT. The list
/// entries are read row by row. Only `=` is allowed (enforced at compile time).
/// Each entry must be a bool, uint64_t, int64_t or double.
template <typename OperatorT, typename ValueT>
class AffectationToTinyMatrixFromListProcessor final : public AffectationToDataVariantProcessorBase
{
 private:
  ASTNode& m_rhs_node;

 public:
  DataVariant
  execute(ExecutionPolicy& exec_policy)
  {
    static_assert(std::is_same_v<OperatorT, language::eq_op>, "forbidden affectation operator for list to matrices");

    AggregateDataVariant children_values = std::get<AggregateDataVariant>(m_rhs_node.execute(exec_policy));

    ValueT v;
    // exactly one list entry per coefficient is expected (children_values[l] is read below)
    Assert(children_values.size() == v.numberOfRows() * v.numberOfColumns());

    // l is the flat (row-major) index of coefficient (i, j)
    for (size_t i = 0, l = 0; i < v.numberOfRows(); ++i) {
      for (size_t j = 0; j < v.numberOfColumns(); ++j, ++l) {
        std::visit(
          [&](auto&& child_value) {
            using T = std::decay_t<decltype(child_value)>;
            if constexpr (std::is_same_v<T, bool> or std::is_same_v<T, uint64_t> or std::is_same_v<T, int64_t> or
                          std::is_same_v<T, double>) {
              v(i, j) = child_value;
            } else {
              // LCOV_EXCL_START
              throw ParseError("unexpected error: unexpected right hand side type in affectation", m_rhs_node.begin());
              // LCOV_EXCL_STOP
            }
          },
          children_values[l]);
      }
    }

    *m_lhs = v;
    return {};
  }

  AffectationToTinyMatrixFromListProcessor(ASTNode& lhs_node, ASTNode& rhs_node)
    : AffectationToDataVariantProcessorBase(lhs_node), m_rhs_node{rhs_node}
  {}
};

/// Affects a single value to a tuple symbol whose elements are ValueT. The lhs
/// becomes a one-element vector. A negative int64_t affected to a uint64_t
/// tuple is reported as a ParseError located at the rhs.
template <typename ValueT>
class AffectationToTupleProcessor final : public AffectationToDataVariantProcessorBase
{
 private:
  ASTNode& m_rhs_node;

 public:
  DataVariant
  execute(ExecutionPolicy& exec_policy)
  {
    DataVariant rhs = m_rhs_node.execute(exec_policy);

    auto affect_single_value = [&](auto&& rhs_value) {
      using T = std::decay_t<decltype(rhs_value)>;

      if constexpr (std::is_same_v<T, ValueT>) {
        *m_lhs = std::vector{std::move(rhs_value)};
      } else if constexpr (std::is_arithmetic_v<ValueT> and std::is_convertible_v<T, ValueT>) {
        if constexpr (std::is_same_v<uint64_t, ValueT> and std::is_same_v<int64_t, T>) {
          if (rhs_value < 0) {
            throw std::domain_error("trying to affect negative value (" + std::to_string(rhs_value) + ")");
          }
        }
        *m_lhs = std::vector<ValueT>{static_cast<ValueT>(rhs_value)};
      } else if constexpr (std::is_same_v<std::string, ValueT>) {
        if constexpr (std::is_arithmetic_v<T>) {
          *m_lhs = std::vector<std::string>{std::to_string(rhs_value)};
        } else {
          std::ostringstream os;
          os << rhs_value;
          *m_lhs = std::vector<std::string>{os.str()};
        }
      } else if constexpr (is_tiny_vector_v<ValueT> or is_tiny_matrix_v<ValueT>) {
        if constexpr (std::is_same_v<ValueT, TinyVector<1>> and std::is_arithmetic_v<T>) {
          *m_lhs = std::vector<TinyVector<1>>{TinyVector<1>{static_cast<double>(rhs_value)}};
        } else if constexpr (std::is_same_v<ValueT, TinyMatrix<1>> and std::is_arithmetic_v<T>) {
          *m_lhs = std::vector<TinyMatrix<1>>{TinyMatrix<1>{static_cast<double>(rhs_value)}};
        } else if constexpr (std::is_same_v<T, int64_t>) {
          // only the literal 0 can be affected to a larger vector/matrix
          Assert(rhs_value == 0);
          *m_lhs = std::vector<ValueT>{ValueT{zero}};
        } else {
          // LCOV_EXCL_START
          throw ParseError("unexpected error: unexpected right hand side type in affectation", m_rhs_node.begin());
          // LCOV_EXCL_STOP
        }
      } else {
        // LCOV_EXCL_START
        throw ParseError("unexpected error: unexpected right hand side type in affectation", m_rhs_node.begin());
        // LCOV_EXCL_STOP
      }
    };

    try {
      std::visit(affect_single_value, rhs);
    }
    catch (const std::domain_error& e) {
      throw ParseError(e.what(), m_rhs_node.begin());
    }
    return {};
  }

  AffectationToTupleProcessor(ASTNode& lhs_node, ASTNode& rhs_node)
    : AffectationToDataVariantProcessorBase(lhs_node), m_rhs_node{rhs_node}
  {}
};

/// Affects either a list expression (AggregateDataVariant) or a tuple value
/// (std::vector<DataType>) to a tuple symbol whose elements are ValueT. Each
/// element is converted individually. Affecting a negative int64_t to a
/// uint64_t element is reported as a ParseError in both cases.
template <typename ValueT>
class AffectationToTupleFromListProcessor final : public AffectationToDataVariantProcessorBase
{
 private:
  ASTNode& m_rhs_node;

  // Converts every entry of a list expression. Errors are located at the
  // offending entry (m_rhs_node.children[i]).
  void
  _copyAggregateDataVariant(const AggregateDataVariant& children_values)
  {
    std::vector<ValueT> tuple_value(children_values.size());
    for (size_t i = 0; i < children_values.size(); ++i) {
      try {
        std::visit(
          [&](auto&& child_value) {
            using T = std::decay_t<decltype(child_value)>;
            if constexpr (std::is_same_v<T, ValueT>) {
              tuple_value[i] = child_value;
            } else if constexpr (std::is_arithmetic_v<ValueT> and std::is_convertible_v<T, ValueT>) {
              if constexpr (std::is_same_v<uint64_t, ValueT> and std::is_same_v<int64_t, T>) {
                if (child_value < 0) {
                  throw std::domain_error("trying to affect negative value (" + std::to_string(child_value) + ")");
                }
              }
              tuple_value[i] = static_cast<ValueT>(child_value);
            } else if constexpr (std::is_same_v<std::string, ValueT>) {
              if constexpr (std::is_arithmetic_v<T>) {
                tuple_value[i] = std::to_string(child_value);
              } else {
                std::ostringstream os;
                os << child_value;
                tuple_value[i] = os.str();
              }
            } else if constexpr (is_tiny_vector_v<ValueT>) {
              if constexpr (std::is_same_v<T, AggregateDataVariant>) {
                ValueT& v = tuple_value[i];
                Assert(ValueT::Dimension == child_value.size());
                for (size_t j = 0; j < ValueT::Dimension; ++j) {
                  std::visit(
                    [&](auto&& vj) {
                      using Ti = std::decay_t<decltype(vj)>;
                      if constexpr (std::is_convertible_v<Ti, typename ValueT::data_type>) {
                        v[j] = vj;
                      } else {
                        // LCOV_EXCL_START
                        throw ParseError("unexpected error: unexpected right hand side type in affectation",
                                         m_rhs_node.children[i]->begin());
                        // LCOV_EXCL_STOP
                      }
                    },
                    child_value[j]);
                }
              } else if constexpr (std::is_arithmetic_v<T>) {
                if constexpr (std::is_same_v<ValueT, TinyVector<1>>) {
                  tuple_value[i][0] = child_value;
                } else {
                  // in this case a 0 is given
                  Assert(child_value == 0);
                  tuple_value[i] = ZeroType{};
                }
              } else {
                // LCOV_EXCL_START
                throw ParseError("unexpected error: unexpected right hand side type in affectation",
                                 m_rhs_node.children[i]->begin());
                // LCOV_EXCL_STOP
              }
            } else if constexpr (is_tiny_matrix_v<ValueT>) {
              if constexpr (std::is_same_v<T, AggregateDataVariant>) {
                ValueT& A = tuple_value[i];
                Assert(A.numberOfRows() * A.numberOfColumns() == child_value.size());
                // l is the flat (row-major) index of coefficient (j, k)
                for (size_t j = 0, l = 0; j < A.numberOfRows(); ++j) {
                  for (size_t k = 0; k < A.numberOfColumns(); ++k, ++l) {
                    std::visit(
                      [&](auto&& Ajk) {
                        using Ti = std::decay_t<decltype(Ajk)>;
                        if constexpr (std::is_convertible_v<Ti, typename ValueT::data_type>) {
                          A(j, k) = Ajk;
                        } else {
                          // LCOV_EXCL_START
                          throw ParseError("unexpected error: unexpected right hand side type in affectation",
                                           m_rhs_node.children[i]->begin());
                          // LCOV_EXCL_STOP
                        }
                      },
                      child_value[l]);
                  }
                }
              } else if constexpr (std::is_arithmetic_v<T>) {
                if constexpr (std::is_same_v<ValueT, TinyMatrix<1>>) {
                  tuple_value[i](0, 0) = child_value;
                } else {
                  // in this case a 0 is given
                  Assert(child_value == 0);
                  tuple_value[i] = ZeroType{};
                }
              } else {
                // LCOV_EXCL_START
                throw ParseError("unexpected error: unexpected right hand side type in affectation",
                                 m_rhs_node.children[i]->begin());
                // LCOV_EXCL_STOP
              }
            } else {
              // LCOV_EXCL_START
              throw ParseError("unexpected error: unexpected right hand side type in affectation",
                               m_rhs_node.children[i]->begin());
              // LCOV_EXCL_STOP
            }
          },
          children_values[i]);
      }
      catch (const std::domain_error& e) {
        throw ParseError(e.what(), m_rhs_node.children[i]->begin());
      }
    }
    *m_lhs = std::move(tuple_value);
  }

  // Converts every element of a tuple value. Tuple elements have no individual
  // source position, so errors are located at the rhs node.
  template <typename DataType>
  void
  _copyVector(const std::vector<DataType>& values)
  {
    std::vector<ValueT> v(values.size());
    if constexpr (std::is_same_v<ValueT, DataType>) {
      for (size_t i = 0; i < values.size(); ++i) {
        v[i] = values[i];
      }
    } else if constexpr (std::is_arithmetic_v<ValueT> and std::is_convertible_v<DataType, ValueT>) {
      try {
        for (size_t i = 0; i < values.size(); ++i) {
          if constexpr (std::is_same_v<uint64_t, ValueT> and std::is_same_v<int64_t, DataType>) {
            // same check as for list entries: negative values must not wrap around
            if (values[i] < 0) {
              throw std::domain_error("trying to affect negative value (" + std::to_string(values[i]) + ")");
            }
          }
          v[i] = static_cast<ValueT>(values[i]);
        }
      }
      catch (const std::domain_error& e) {
        throw ParseError(e.what(), m_rhs_node.begin());
      }
    } else if constexpr (std::is_same_v<ValueT, std::string>) {
      if constexpr (std::is_arithmetic_v<DataType>) {
        for (size_t i = 0; i < values.size(); ++i) {
          v[i] = std::to_string(values[i]);
        }
      } else {
        for (size_t i = 0; i < values.size(); ++i) {
          std::ostringstream sout;
          sout << values[i];
          v[i] = sout.str();
        }
      }
    } else {
      // LCOV_EXCL_START
      throw ParseError("unexpected error: unexpected right hand side type in tuple affectation", m_rhs_node.begin());
      // LCOV_EXCL_STOP
    }

    *m_lhs = std::move(v);
  }

 public:
  DataVariant
  execute(ExecutionPolicy& exec_policy)
  {
    std::visit(
      [&](auto&& value_list) {
        using ValueListT = std::decay_t<decltype(value_list)>;
        if constexpr (std::is_same_v<AggregateDataVariant, ValueListT>) {
          this->_copyAggregateDataVariant(value_list);
        } else if constexpr (is_std_vector_v<ValueListT>) {
          this->_copyVector(value_list);
        } else {
          // LCOV_EXCL_START
          throw ParseError("unexpected error: invalid rhs (expecting list or tuple)", std::vector{m_rhs_node.begin()});
          // LCOV_EXCL_STOP
        }
      },
      m_rhs_node.execute(exec_policy));

    return {};
  }

  AffectationToTupleFromListProcessor(ASTNode& lhs_node, ASTNode& rhs_node)
    : AffectationToDataVariantProcessorBase(lhs_node), m_rhs_node{rhs_node}
  {}
};

template <typename ValueT>
class AffectationFromZeroProcessor final : public AffectationToDataVariantProcessorBase
{
 public:
  DataVariant
  execute(ExecutionPolicy&)
  {
    *m_lhs = ValueT{zero};
    return {};
  }

  AffectationFromZeroProcessor(ASTNode& lhs_node) : AffectationToDataVariantProcessorBase(lhs_node) {}
};

/// Node processor for a list affectation `(a, b, ...) OperatorT (x, y, ...)`.
/// The rhs is m_node.children[1]. One executor is registered per lhs name via
/// add(), in the order of the list. execute() applies the i-th rhs entry to the
/// i-th executor. std::domain_error is reported at the corresponding rhs entry.
template <typename OperatorT>
class ListAffectationProcessor final : public INodeProcessor
{
 private:
  ASTNode& m_node;

  std::vector<std::unique_ptr<IAffectationExecutor>> m_affectation_executor_list;

 public:
  template <typename ValueT, typename DataT>
  void
  add(ASTNode& lhs_node)
  {
    if (not lhs_node.is_type<language::name>()) {
      // LCOV_EXCL_START
      throw ParseError("unexpected error: invalid left hand side", std::vector{lhs_node.begin()});
      // LCOV_EXCL_STOP
    }

    const std::string& symbol_name = lhs_node.string();
    auto [i_symbol, found]         = m_node.m_symbol_table->find(symbol_name, m_node.children[0]->end());
    Assert(found);

    DataVariant& lhs_value = i_symbol->attributes().value();

    // make sure the storage holds a ValueT before the executor uses it
    if (not std::holds_alternative<ValueT>(lhs_value)) {
      lhs_value = ValueT{};
    }

    using AffectationExecutorT = AffectationExecutor<OperatorT, ValueT, DataT>;
    m_affectation_executor_list.push_back(std::make_unique<AffectationExecutorT>(m_node, lhs_value));
  }

  DataVariant
  execute(ExecutionPolicy& exec_policy)
  {
    ASTNode& rhs_list_node               = *m_node.children[1];
    AggregateDataVariant children_values = std::get<AggregateDataVariant>(rhs_list_node.execute(exec_policy));
    Assert(m_affectation_executor_list.size() == children_values.size());

    for (size_t i_value = 0; i_value < m_affectation_executor_list.size(); ++i_value) {
      try {
        m_affectation_executor_list[i_value]->affect(exec_policy, std::move(children_values[i_value]));
      }
      catch (const std::domain_error& e) {
        throw ParseError(e.what(), rhs_list_node.children[i_value]->begin());
      }
    }
    return {};
  }

  ListAffectationProcessor(ASTNode& node) : m_node{node} {}
};

#endif   // AFFECTATION_PROCESSOR_HPP
