#include <language/utils/EmbeddedIDiscreteFunctionOperators.hpp>

#include <language/node_processor/BinaryExpressionProcessor.hpp>
#include <language/node_processor/UnaryExpressionProcessor.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionP0Vector.hpp>
#include <scheme/IDiscreteFunction.hpp>
#include <utils/Exceptions.hpp>

/**
 * Builds a human-readable type name for an operand, used in error messages.
 *
 * - shared pointers are dereferenced and described through their pointee;
 * - discrete functions are rendered as "Vh(<discretization>:<value type>)";
 * - any other operand is described by its language data type name.
 */
template <typename T>
PUGS_INLINE std::string
operand_type_name(const T& t)
{
  if constexpr (is_shared_ptr_v<T>) {
    // the pointer must manage an object since it is dereferenced
    Assert(t.use_count() > 0);
    const auto& pointee = *t;
    return operand_type_name(pointee);
  } else if constexpr (std::is_base_of_v<IDiscreteFunction, std::decay_t<T>>) {
    const auto discretization_name = name(t.descriptor().type());
    const auto value_type_name     = dataTypeName(t.dataType());
    return "Vh(" + discretization_name + ':' + value_type_name + ')';
  } else {
    return dataTypeName(ast_node_data_type_from<T>);
  }
}

/**
 * Tells whether two discrete functions share the same discretization: same
 * discretization kind (descriptor type) and same value type, including the
 * vector dimension or the matrix shape.
 *
 * @throws UnexpectedError if the common value type is neither double, vector
 *         nor matrix
 */
PUGS_INLINE
bool
isSameDiscretization(const IDiscreteFunction& f, const IDiscreteFunction& g)
{
  // different value kinds or different discretization kinds never match
  if (not(f.dataType() == g.dataType()) or not(f.descriptor().type() == g.descriptor().type())) {
    return false;
  }

  switch (f.dataType()) {
  case ASTNodeDataType::double_t:
    return true;
  case ASTNodeDataType::vector_t:
    return f.dataType().dimension() == g.dataType().dimension();
  case ASTNodeDataType::matrix_t:
    if (f.dataType().nbRows() != g.dataType().nbRows()) {
      return false;
    }
    return f.dataType().nbColumns() == g.dataType().nbColumns();
  default:
    throw UnexpectedError("invalid data type " + operand_type_name(f));
  }
}

/**
 * Shared-pointer convenience overload of isSameDiscretization.
 * Both pointers are dereferenced, so they must be non-null.
 */
PUGS_INLINE
bool
isSameDiscretization(const std::shared_ptr<const IDiscreteFunction>& f,
                     const std::shared_ptr<const IDiscreteFunction>& g)
{
  const IDiscreteFunction& lhs = *f;
  const IDiscreteFunction& rhs = *g;
  return isSameDiscretization(lhs, rhs);
}

/**
 * Builds the error message reported when a binary operator is not defined for
 * the given pair of operands (both described through operand_type_name).
 */
template <typename LHS_T, typename RHS_T>
PUGS_INLINE std::string
invalid_operands(const LHS_T& f, const RHS_T& g)
{
  std::string message = "undefined binary operator\n";
  message += "note: incompatible operand types ";
  message += operand_type_name(f);
  message += " and ";
  message += operand_type_name(g);
  return message;
}

// unary operators
template <typename UnaryOperatorT, typename DiscreteFunctionT>
std::shared_ptr<const IDiscreteFunction>
applyUnaryOperation(const DiscreteFunctionT& f)
{
  return std::make_shared<decltype(UnaryOp<UnaryOperatorT>{}.eval(f))>(UnaryOp<UnaryOperatorT>{}.eval(f));
}

/**
 * Applies a unary operator to a type-erased discrete function defined on a
 * mesh of the given @a Dimension, after recovering its concrete type.
 *
 * Handled operands: P0 functions of double, TinyVector<1..3> or square
 * TinyMatrix<1..3> values, and P0Vector functions of double.
 *
 * @throws UnexpectedError for any other discretization / value type
 */
template <typename UnaryOperatorT, size_t Dimension>
std::shared_ptr<const IDiscreteFunction>
applyUnaryOperation(const std::shared_ptr<const IDiscreteFunction>& f)
{
  // Concrete operands are bound by const reference: f keeps the object alive
  // for the whole call, so copying the discrete function is unnecessary.
  switch (f->descriptor().type()) {
  case DiscreteFunctionType::P0: {
    switch (f->dataType()) {
    case ASTNodeDataType::double_t: {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f);
      return applyUnaryOperation<UnaryOperatorT>(fh);
    }
    case ASTNodeDataType::vector_t: {
      switch (f->dataType().dimension()) {
      case 1: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f);
        return applyUnaryOperation<UnaryOperatorT>(fh);
      }
      case 2: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f);
        return applyUnaryOperation<UnaryOperatorT>(fh);
      }
      case 3: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f);
        return applyUnaryOperation<UnaryOperatorT>(fh);
      }
      default: {
        throw UnexpectedError("invalid operand type " + operand_type_name(f));
      }
      }
    }
    case ASTNodeDataType::matrix_t: {
      // only square matrices are instantiated below
      Assert(f->dataType().nbRows() == f->dataType().nbColumns());
      switch (f->dataType().nbRows()) {
      case 1: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f);
        return applyUnaryOperation<UnaryOperatorT>(fh);
      }
      case 2: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f);
        return applyUnaryOperation<UnaryOperatorT>(fh);
      }
      case 3: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f);
        return applyUnaryOperation<UnaryOperatorT>(fh);
      }
      default: {
        throw UnexpectedError("invalid operand type " + operand_type_name(f));
      }
      }
    }
    default: {
      throw UnexpectedError("invalid operand type " + operand_type_name(f));
    }
    }
    break;
  }
  case DiscreteFunctionType::P0Vector: {
    // P0Vector functions only carry scalar (double) components
    switch (f->dataType()) {
    case ASTNodeDataType::double_t: {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0Vector<Dimension, double>&>(*f);
      return applyUnaryOperation<UnaryOperatorT>(fh);
    }
    default: {
      throw UnexpectedError("invalid operand type " + operand_type_name(f));
    }
    }
    break;
  }
  default: {
    throw UnexpectedError("invalid operand type " + operand_type_name(f));
  }
  }
}

/**
 * Entry point for unary operators on type-erased discrete functions: picks the
 * mesh dimension at runtime and forwards to the matching instantiation.
 *
 * @throws UnexpectedError if the mesh dimension is not 1, 2 or 3
 */
template <typename UnaryOperatorT>
std::shared_ptr<const IDiscreteFunction>
applyUnaryOperation(const std::shared_ptr<const IDiscreteFunction>& f)
{
  const auto mesh_dimension = f->mesh()->dimension();

  if (mesh_dimension == 1) {
    return applyUnaryOperation<UnaryOperatorT, 1>(f);
  } else if (mesh_dimension == 2) {
    return applyUnaryOperation<UnaryOperatorT, 2>(f);
  } else if (mesh_dimension == 3) {
    return applyUnaryOperation<UnaryOperatorT, 3>(f);
  }

  throw UnexpectedError("invalid mesh dimension");
}

/// Unary minus on a type-erased discrete function; returns -f as a new function.
std::shared_ptr<const IDiscreteFunction>
operator-(const std::shared_ptr<const IDiscreteFunction>& f)
{
  using UnaryMinusT = language::unary_minus;
  return applyUnaryOperation<UnaryMinusT>(f);
}

// binary operators

template <typename BinOperatorT, typename DiscreteFunctionT>
std::shared_ptr<const IDiscreteFunction>
innerCompositionLaw(const DiscreteFunctionT& lhs, const DiscreteFunctionT& rhs)
{
  Assert(lhs.mesh() == rhs.mesh());
  using data_type = typename DiscreteFunctionT::data_type;
  if constexpr ((std::is_same_v<language::multiply_op, BinOperatorT> and is_tiny_vector_v<data_type>) or
                (std::is_same_v<language::divide_op, BinOperatorT> and not std::is_arithmetic_v<data_type>)) {
    throw NormalError(invalid_operands(lhs, rhs));
  } else {
    return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(lhs, rhs))>(BinOp<BinOperatorT>{}.eval(lhs, rhs));
  }
}

/**
 * Applies a binary operator to two type-erased discrete functions that share
 * the same mesh (of the given @a Dimension) and the same discretization.
 *
 * P0 functions of double, TinyVector<1..3> and square TinyMatrix<1..3> are
 * supported. P0Vector functions of double only support + and -, and their
 * sizes must agree.
 *
 * @throws NormalError when the operator is not defined for these operands
 * @throws UnexpectedError on unexpected discretizations or value types
 */
template <typename BinOperatorT, size_t Dimension>
std::shared_ptr<const IDiscreteFunction>
innerCompositionLaw(const std::shared_ptr<const IDiscreteFunction>& f,
                    const std::shared_ptr<const IDiscreteFunction>& g)
{
  Assert(f->mesh() == g->mesh());
  Assert(isSameDiscretization(f, g));

  // Concrete operands are bound by const reference: f and g keep the objects
  // alive for the whole call, so copying the discrete functions is unnecessary.
  switch (f->dataType()) {
  case ASTNodeDataType::double_t: {
    if (f->descriptor().type() == DiscreteFunctionType::P0) {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f);
      const auto& gh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*g);

      return innerCompositionLaw<BinOperatorT>(fh, gh);

    } else if (f->descriptor().type() == DiscreteFunctionType::P0Vector) {
      if constexpr (std::is_same_v<BinOperatorT, language::plus_op> or
                    std::is_same_v<BinOperatorT, language::minus_op>) {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0Vector<Dimension, double>&>(*f);
        const auto& gh = dynamic_cast<const DiscreteFunctionP0Vector<Dimension, double>&>(*g);

        if (fh.size() != gh.size()) {
          throw NormalError(operand_type_name(f) + " spaces have different sizes");
        }

        return innerCompositionLaw<BinOperatorT>(fh, gh);
      } else {
        throw NormalError(invalid_operands(f, g));
      }
    } else {
      throw UnexpectedError(invalid_operands(f, g));
    }
  }
  case ASTNodeDataType::vector_t: {
    Assert(f->descriptor().type() == DiscreteFunctionType::P0);
    switch (f->dataType().dimension()) {
    case 1: {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f);
      const auto& gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*g);

      return innerCompositionLaw<BinOperatorT>(fh, gh);
    }
    case 2: {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f);
      const auto& gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*g);

      return innerCompositionLaw<BinOperatorT>(fh, gh);
    }
    case 3: {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f);
      const auto& gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*g);

      return innerCompositionLaw<BinOperatorT>(fh, gh);
    }
    default: {
      throw NormalError(invalid_operands(f, g));
    }
    }
  }
  case ASTNodeDataType::matrix_t: {
    Assert(f->descriptor().type() == DiscreteFunctionType::P0);
    Assert(f->dataType().nbRows() == f->dataType().nbColumns());
    switch (f->dataType().nbRows()) {
    case 1: {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f);
      const auto& gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*g);

      return innerCompositionLaw<BinOperatorT>(fh, gh);
    }
    case 2: {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f);
      const auto& gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*g);

      return innerCompositionLaw<BinOperatorT>(fh, gh);
    }
    case 3: {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f);
      const auto& gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*g);

      return innerCompositionLaw<BinOperatorT>(fh, gh);
    }
    default: {
      throw UnexpectedError("invalid data type " + operand_type_name(f));
    }
    }
  }
  default: {
    throw UnexpectedError("invalid data type " + operand_type_name(f));
  }
  }
}

/**
 * Entry point for operators between two discrete functions sharing the same
 * discretization. It validates the operands, then forwards to the
 * instantiation matching the mesh dimension.
 *
 * @throws NormalError if the functions live on different meshes or use
 *         different discretizations
 * @throws UnexpectedError if the mesh dimension is not 1, 2 or 3
 */
template <typename BinOperatorT>
std::shared_ptr<const IDiscreteFunction>
innerCompositionLaw(const std::shared_ptr<const IDiscreteFunction>& f,
                    const std::shared_ptr<const IDiscreteFunction>& g)
{
  // user-facing validations come first
  if (f->mesh() != g->mesh()) {
    throw NormalError("discrete functions defined on different meshes");
  }
  if (not isSameDiscretization(f, g)) {
    throw NormalError(invalid_operands(f, g));
  }

  const auto mesh_dimension = f->mesh()->dimension();

  if (mesh_dimension == 1) {
    return innerCompositionLaw<BinOperatorT, 1>(f, g);
  } else if (mesh_dimension == 2) {
    return innerCompositionLaw<BinOperatorT, 2>(f, g);
  } else if (mesh_dimension == 3) {
    return innerCompositionLaw<BinOperatorT, 3>(f, g);
  }

  throw UnexpectedError("invalid mesh dimension");
}

template <typename BinOperatorT, typename LeftDiscreteFunctionT, typename RightDiscreteFunctionT>
std::shared_ptr<const IDiscreteFunction>
applyBinaryOperation(const LeftDiscreteFunctionT& lhs, const RightDiscreteFunctionT& rhs)
{
  Assert(lhs.mesh() == rhs.mesh());

  static_assert(not std::is_same_v<LeftDiscreteFunctionT, RightDiscreteFunctionT>,
                "use innerCompositionLaw when data types are the same");

  return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(lhs, rhs))>(BinOp<BinOperatorT>{}.eval(lhs, rhs));
}

template <typename BinOperatorT, size_t Dimension, typename DiscreteFunctionT>
std::shared_ptr<const IDiscreteFunction>
applyBinaryOperation(const DiscreteFunctionT& fh, const std::shared_ptr<const IDiscreteFunction>& g)
{
  Assert(fh.mesh() == g->mesh());
  Assert(not isSameDiscretization(fh, *g));
  using lhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>;

  switch (g->dataType()) {
  case ASTNodeDataType::double_t: {
    if constexpr (not std::is_same_v<lhs_data_type, double>) {
      if constexpr (not is_tiny_matrix_v<lhs_data_type>) {
        auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*g);

        return applyBinaryOperation<BinOperatorT>(fh, gh);
      } else {
        throw NormalError(invalid_operands(fh, g));
      }
    } else if constexpr (std::is_same_v<BinOperatorT, language::multiply_op> and
                         std::is_same_v<DiscreteFunctionT, DiscreteFunctionP0<Dimension, double>>) {
      if (g->descriptor().type() == DiscreteFunctionType::P0Vector) {
        auto gh = dynamic_cast<const DiscreteFunctionP0Vector<Dimension, double>&>(*g);
        return applyBinaryOperation<BinOperatorT>(fh, gh);
      } else {
        throw NormalError(invalid_operands(fh, g));
      }
    } else {
      throw UnexpectedError("should have called innerCompositionLaw");
    }
  }
  case ASTNodeDataType::vector_t: {
    if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) {
      switch (g->dataType().dimension()) {
      case 1: {
        if constexpr (not is_tiny_vector_v<lhs_data_type> and
                      (std::is_same_v<lhs_data_type, TinyMatrix<1>> or std::is_same_v<lhs_data_type, double>)) {
          auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*g);

          return applyBinaryOperation<BinOperatorT>(fh, gh);
        } else {
          throw NormalError(invalid_operands(fh, g));
        }
      }
      case 2: {
        if constexpr (not is_tiny_vector_v<lhs_data_type> and
                      (std::is_same_v<lhs_data_type, TinyMatrix<2>> or std::is_same_v<lhs_data_type, double>)) {
          auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*g);

          return applyBinaryOperation<BinOperatorT>(fh, gh);
        } else {
          throw NormalError(invalid_operands(fh, g));
        }
      }
      case 3: {
        if constexpr (not is_tiny_vector_v<lhs_data_type> and
                      (std::is_same_v<lhs_data_type, TinyMatrix<3>> or std::is_same_v<lhs_data_type, double>)) {
          auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*g);

          return applyBinaryOperation<BinOperatorT>(fh, gh);
        } else {
          throw NormalError(invalid_operands(fh, g));
        }
      }
      default: {
        throw UnexpectedError("invalid rhs data type " + operand_type_name(g));
      }
      }
    } else {
      throw NormalError(invalid_operands(fh, g));
    }
  }
  case ASTNodeDataType::matrix_t: {
    Assert(g->dataType().nbRows() == g->dataType().nbColumns());
    if constexpr (std::is_same_v<lhs_data_type, double> and std::is_same_v<language::multiply_op, BinOperatorT>) {
      switch (g->dataType().nbRows()) {
      case 1: {
        auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*g);

        return applyBinaryOperation<BinOperatorT>(fh, gh);
      }
      case 2: {
        auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*g);

        return applyBinaryOperation<BinOperatorT>(fh, gh);
      }
      case 3: {
        auto gh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*g);

        return applyBinaryOperation<BinOperatorT>(fh, gh);
      }
      default: {
        throw UnexpectedError("invalid rhs data type " + operand_type_name(g));
      }
      }
    } else {
      throw NormalError(invalid_operands(fh, g));
    }
  }
  default: {
    throw UnexpectedError("invalid rhs data type " + operand_type_name(g));
  }
  }
}

/**
 * Applies a binary operator between two type-erased discrete functions that
 * live on the same mesh (of the given @a Dimension) but use different
 * discretizations. It recovers the concrete left operand type, then forwards.
 *
 * Only P0 left operands of double or square TinyMatrix<1..3> values are handled.
 *
 * @throws NormalError when the operator is not defined for these operands
 * @throws UnexpectedError on unexpected left operand matrix sizes
 */
template <typename BinOperatorT, size_t Dimension>
std::shared_ptr<const IDiscreteFunction>
applyBinaryOperation(const std::shared_ptr<const IDiscreteFunction>& f,
                     const std::shared_ptr<const IDiscreteFunction>& g)
{
  Assert(f->mesh() == g->mesh());
  Assert(not isSameDiscretization(f, g));

  // Concrete left operands are bound by const reference (no copy needed).
  switch (f->dataType()) {
  case ASTNodeDataType::double_t: {
    // A scalar function may be a P0Vector (e.g. P0Vector * P0). Report an
    // undefined operator instead of letting the P0 dynamic_cast throw
    // std::bad_cast.
    if (f->descriptor().type() != DiscreteFunctionType::P0) {
      throw NormalError(invalid_operands(f, g));
    }
    const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f);

    return applyBinaryOperation<BinOperatorT, Dimension>(fh, g);
  }
  case ASTNodeDataType::matrix_t: {
    Assert(f->dataType().nbRows() == f->dataType().nbColumns());
    switch (f->dataType().nbRows()) {
    case 1: {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f);

      return applyBinaryOperation<BinOperatorT, Dimension>(fh, g);
    }
    case 2: {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f);

      return applyBinaryOperation<BinOperatorT, Dimension>(fh, g);
    }
    case 3: {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f);

      return applyBinaryOperation<BinOperatorT, Dimension>(fh, g);
    }
    default: {
      throw UnexpectedError("invalid lhs data type " + operand_type_name(f));
    }
    }
  }
  default: {
    throw NormalError(invalid_operands(f, g));
  }
  }
}

/**
 * Entry point for operators between two discrete functions with different
 * discretizations. It checks the meshes, then forwards to the instantiation
 * matching the mesh dimension.
 *
 * @throws NormalError if the functions live on different meshes
 * @throws UnexpectedError if the mesh dimension is not 1, 2 or 3
 */
template <typename BinOperatorT>
std::shared_ptr<const IDiscreteFunction>
applyBinaryOperation(const std::shared_ptr<const IDiscreteFunction>& f,
                     const std::shared_ptr<const IDiscreteFunction>& g)
{
  if (f->mesh() != g->mesh()) {
    throw NormalError("functions defined on different meshes");
  }

  Assert(not isSameDiscretization(f, g), "should call inner composition instead");

  const auto mesh_dimension = f->mesh()->dimension();

  if (mesh_dimension == 1) {
    return applyBinaryOperation<BinOperatorT, 1>(f, g);
  } else if (mesh_dimension == 2) {
    return applyBinaryOperation<BinOperatorT, 2>(f, g);
  } else if (mesh_dimension == 3) {
    return applyBinaryOperation<BinOperatorT, 3>(f, g);
  }

  throw UnexpectedError("invalid mesh dimension");
}

/// Sum of two discrete functions; only defined for identical discretizations.
std::shared_ptr<const IDiscreteFunction>
operator+(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g)
{
  if (not isSameDiscretization(f, g)) {
    throw NormalError(invalid_operands(f, g));
  }
  return innerCompositionLaw<language::plus_op>(f, g);
}

/// Difference of two discrete functions; only defined for identical discretizations.
std::shared_ptr<const IDiscreteFunction>
operator-(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g)
{
  if (not isSameDiscretization(f, g)) {
    throw NormalError(invalid_operands(f, g));
  }
  return innerCompositionLaw<language::minus_op>(f, g);
}

/**
 * Product of two discrete functions. Identical discretizations use the inner
 * composition law; mixed ones use the heterogeneous binary operation.
 */
std::shared_ptr<const IDiscreteFunction>
operator*(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g)
{
  return isSameDiscretization(f, g) ? innerCompositionLaw<language::multiply_op>(f, g)
                                    : applyBinaryOperation<language::multiply_op>(f, g);
}

/**
 * Quotient of two discrete functions. Identical discretizations use the inner
 * composition law; mixed ones use the heterogeneous binary operation.
 */
std::shared_ptr<const IDiscreteFunction>
operator/(const std::shared_ptr<const IDiscreteFunction>& f, const std::shared_ptr<const IDiscreteFunction>& g)
{
  return isSameDiscretization(f, g) ? innerCompositionLaw<language::divide_op>(f, g)
                                    : applyBinaryOperation<language::divide_op>(f, g);
}

template <typename BinOperatorT, typename DataType, typename DiscreteFunctionT>
std::shared_ptr<const IDiscreteFunction>
applyBinaryOperationWithLeftConstant(const DataType& a, const DiscreteFunctionT& f)
{
  using lhs_data_type = std::decay_t<DataType>;
  using rhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>;

  if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) {
    if constexpr (std::is_same_v<lhs_data_type, double>) {
      return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f));
    } else if constexpr (is_tiny_matrix_v<lhs_data_type> and
                         (is_tiny_matrix_v<rhs_data_type> or is_tiny_vector_v<rhs_data_type>)) {
      return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f));
    } else {
      throw NormalError(invalid_operands(a, f));
    }
  } else if constexpr (std::is_same_v<language::plus_op, BinOperatorT> or
                       std::is_same_v<language::minus_op, BinOperatorT>) {
    if constexpr (std::is_same_v<lhs_data_type, double> and std::is_arithmetic_v<rhs_data_type>) {
      return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f));
    } else if constexpr (std::is_same_v<lhs_data_type, rhs_data_type>) {
      return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f));
    } else {
      throw NormalError(invalid_operands(a, f));
    }
  } else if constexpr (std::is_same_v<language::divide_op, BinOperatorT>) {
    if constexpr (std::is_same_v<lhs_data_type, double> and std::is_arithmetic_v<rhs_data_type>) {
      return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f));
    } else {
      throw NormalError(invalid_operands(a, f));
    }
  } else {
    throw NormalError(invalid_operands(a, f));
  }
}

template <typename BinOperatorT, typename DataType, typename DiscreteFunctionT>
std::shared_ptr<const IDiscreteFunction>
applyBinaryOperationToVectorWithLeftConstant(const DataType& a, const DiscreteFunctionT& f)
{
  using lhs_data_type = std::decay_t<DataType>;
  using rhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>;

  if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) {
    if constexpr (std::is_same_v<lhs_data_type, double>) {
      return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f));
    } else if constexpr (is_tiny_matrix_v<lhs_data_type> and
                         (is_tiny_matrix_v<rhs_data_type> or is_tiny_vector_v<rhs_data_type>)) {
      return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(a, f))>(BinOp<BinOperatorT>{}.eval(a, f));
    } else {
      throw NormalError(invalid_operands(a, f));
    }
  } else {
    throw NormalError(invalid_operands(a, f));
  }
}

/**
 * Applies a binary operator between a left constant @a a and a type-erased
 * discrete function @a f (the RIGHT operand) defined on a mesh of the given
 * @a Dimension.
 *
 * Scalar functions are dispatched on their discretization (P0 or P0Vector).
 * Vector and matrix valued functions are handled as P0. When @a a is a
 * TinyMatrix, its size must match the dimension of @a f's values.
 *
 * @throws NormalError when the operator is not defined for these operands
 * @throws UnexpectedError on unexpected value dimensions of @a f
 */
template <typename BinOperatorT, size_t Dimension, typename DataType>
std::shared_ptr<const IDiscreteFunction>
applyBinaryOperationWithLeftConstant(const DataType& a, const std::shared_ptr<const IDiscreteFunction>& f)
{
  // Concrete operands are bound by const reference (no copy needed). Note that
  // f is the right operand here, which is what the unexpected-type messages say.
  switch (f->dataType()) {
  case ASTNodeDataType::bool_t:
  case ASTNodeDataType::unsigned_int_t:
  case ASTNodeDataType::int_t:
  case ASTNodeDataType::double_t: {
    if (f->descriptor().type() == DiscreteFunctionType::P0) {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f);
      return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
    } else if (f->descriptor().type() == DiscreteFunctionType::P0Vector) {
      const auto& fh = dynamic_cast<const DiscreteFunctionP0Vector<Dimension, double>&>(*f);
      return applyBinaryOperationToVectorWithLeftConstant<BinOperatorT>(a, fh);
    } else {
      throw NormalError(invalid_operands(a, f));
    }
  }
  case ASTNodeDataType::vector_t: {
    if constexpr (is_tiny_matrix_v<DataType>) {
      // a matrix constant only applies to vectors of the same dimension
      switch (f->dataType().dimension()) {
      case 1: {
        if constexpr (std::is_same_v<DataType, TinyMatrix<1>>) {
          const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f);
          return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
        } else {
          throw NormalError(invalid_operands(a, f));
        }
      }
      case 2: {
        if constexpr (std::is_same_v<DataType, TinyMatrix<2>>) {
          const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f);
          return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
        } else {
          throw NormalError(invalid_operands(a, f));
        }
      }
      case 3: {
        if constexpr (std::is_same_v<DataType, TinyMatrix<3>>) {
          const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f);
          return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
        } else {
          throw NormalError(invalid_operands(a, f));
        }
      }
      default: {
        throw UnexpectedError("invalid rhs data type " + operand_type_name(f));
      }
      }
    } else {
      switch (f->dataType().dimension()) {
      case 1: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<1>>&>(*f);
        return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
      }
      case 2: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<2>>&>(*f);
        return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
      }
      case 3: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyVector<3>>&>(*f);
        return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
      }
      default: {
        throw UnexpectedError("invalid rhs data type " + operand_type_name(f));
      }
      }
    }
  }
  case ASTNodeDataType::matrix_t: {
    Assert(f->dataType().nbRows() == f->dataType().nbColumns());
    if constexpr (is_tiny_matrix_v<DataType>) {
      // a matrix constant only applies to matrices of the same size
      switch (f->dataType().nbRows()) {
      case 1: {
        if constexpr (std::is_same_v<DataType, TinyMatrix<1>>) {
          const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f);
          return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
        } else {
          throw NormalError(invalid_operands(a, f));
        }
      }
      case 2: {
        if constexpr (std::is_same_v<DataType, TinyMatrix<2>>) {
          const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f);
          return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
        } else {
          throw NormalError(invalid_operands(a, f));
        }
      }
      case 3: {
        if constexpr (std::is_same_v<DataType, TinyMatrix<3>>) {
          const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f);
          return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
        } else {
          throw NormalError(invalid_operands(a, f));
        }
      }
      default: {
        throw UnexpectedError("invalid rhs data type " + operand_type_name(f));
      }
      }
    } else {
      switch (f->dataType().nbRows()) {
      case 1: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f);
        return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
      }
      case 2: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f);
        return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
      }
      case 3: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f);
        return applyBinaryOperationWithLeftConstant<BinOperatorT>(a, fh);
      }
      default: {
        throw UnexpectedError("invalid rhs data type " + operand_type_name(f));
      }
      }
    }
  }
  default: {
    throw NormalError(invalid_operands(a, f));
  }
  }
}

/**
 * Entry point for "constant op function" expressions: forwards to the
 * instantiation matching the mesh dimension of @a f.
 *
 * @throws UnexpectedError if the mesh dimension is not 1, 2 or 3
 */
template <typename BinOperatorT, typename DataType>
std::shared_ptr<const IDiscreteFunction>
applyBinaryOperationWithLeftConstant(const DataType& a, const std::shared_ptr<const IDiscreteFunction>& f)
{
  const auto mesh_dimension = f->mesh()->dimension();

  if (mesh_dimension == 1) {
    return applyBinaryOperationWithLeftConstant<BinOperatorT, 1>(a, f);
  } else if (mesh_dimension == 2) {
    return applyBinaryOperationWithLeftConstant<BinOperatorT, 2>(a, f);
  } else if (mesh_dimension == 3) {
    return applyBinaryOperationWithLeftConstant<BinOperatorT, 3>(a, f);
  }

  throw UnexpectedError("invalid mesh dimension");
}

/**
 * Applies a binary operator between a concrete P0 discrete function @a f and a
 * right constant @a a.
 *
 * Only products are implemented: matrix-valued f times a matrix constant, and
 * double-valued f times a matrix or vector constant.
 *
 * @throws NormalError for every other combination
 */
template <typename BinOperatorT, typename DataType, typename DiscreteFunctionT>
std::shared_ptr<const IDiscreteFunction>
applyBinaryOperationWithRightConstant(const DiscreteFunctionT& f, const DataType& a)
{
  // The dispatching overload rejects every non-P0 function before forwarding
  // here, so f is always P0. The former `!=` check failed on every debug call.
  Assert(f.descriptor().type() == DiscreteFunctionType::P0);

  using lhs_data_type = std::decay_t<typename DiscreteFunctionT::data_type>;
  using rhs_data_type = std::decay_t<DataType>;

  if constexpr (std::is_same_v<language::multiply_op, BinOperatorT>) {
    if constexpr (is_tiny_matrix_v<lhs_data_type> and is_tiny_matrix_v<rhs_data_type>) {
      return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(f, a))>(BinOp<BinOperatorT>{}.eval(f, a));
    } else if constexpr (std::is_same_v<lhs_data_type, double> and
                         (is_tiny_matrix_v<rhs_data_type> or is_tiny_vector_v<rhs_data_type>)) {
      return std::make_shared<decltype(BinOp<BinOperatorT>{}.eval(f, a))>(BinOp<BinOperatorT>{}.eval(f, a));
    } else {
      throw NormalError(invalid_operands(f, a));
    }
  } else {
    throw NormalError(invalid_operands(f, a));
  }
}

/**
 * Applies a binary operator between a type-erased P0 discrete function @a f
 * (defined on a mesh of the given @a Dimension) and a right constant @a a.
 *
 * Only P0 functions of scalar or square TinyMatrix<1..3> values are handled.
 * When @a a is a TinyMatrix, its size must match the size of @a f's values.
 *
 * @throws NormalError when the operator is not defined for these operands
 * @throws UnexpectedError on unexpected matrix sizes of @a f
 */
template <typename BinOperatorT, size_t Dimension, typename DataType>
std::shared_ptr<const IDiscreteFunction>
applyBinaryOperationWithRightConstant(const std::shared_ptr<const IDiscreteFunction>& f, const DataType& a)
{
  if (f->descriptor().type() != DiscreteFunctionType::P0) {
    throw NormalError(invalid_operands(f, a));
  }

  // Concrete operands are bound by const reference (no copy needed).
  switch (f->dataType()) {
  case ASTNodeDataType::bool_t:
  case ASTNodeDataType::unsigned_int_t:
  case ASTNodeDataType::int_t:
  case ASTNodeDataType::double_t: {
    const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, double>&>(*f);
    return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a);
  }
  case ASTNodeDataType::matrix_t: {
    Assert(f->dataType().nbRows() == f->dataType().nbColumns());
    if constexpr (is_tiny_matrix_v<DataType>) {
      // a matrix constant only applies to matrices of the same size
      switch (f->dataType().nbRows()) {
      case 1: {
        if constexpr (std::is_same_v<DataType, TinyMatrix<1>>) {
          const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f);
          return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a);
        } else {
          throw NormalError(invalid_operands(f, a));
        }
      }
      case 2: {
        if constexpr (std::is_same_v<DataType, TinyMatrix<2>>) {
          const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f);
          return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a);
        } else {
          throw NormalError(invalid_operands(f, a));
        }
      }
      case 3: {
        if constexpr (std::is_same_v<DataType, TinyMatrix<3>>) {
          const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f);
          return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a);
        } else {
          throw NormalError(invalid_operands(f, a));
        }
      }
      default: {
        throw UnexpectedError("invalid lhs data type " + operand_type_name(f));
      }
      }
    } else {
      switch (f->dataType().nbRows()) {
      case 1: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<1>>&>(*f);
        return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a);
      }
      case 2: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<2>>&>(*f);
        return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a);
      }
      case 3: {
        const auto& fh = dynamic_cast<const DiscreteFunctionP0<Dimension, TinyMatrix<3>>&>(*f);
        return applyBinaryOperationWithRightConstant<BinOperatorT>(fh, a);
      }
      default: {
        throw UnexpectedError("invalid lhs data type " + operand_type_name(f));
      }
      }
    }
  }
  default: {
    throw NormalError(invalid_operands(f, a));
  }
  }
}

// Bridges the runtime mesh dimension of f to the compile-time Dimension
// parameter: forwards (f, a) to the Dimension-templated overload of
// applyBinaryOperationWithRightConstant matching f's mesh.
// Throws UnexpectedError when the mesh dimension is not 1, 2 or 3.
template <typename BinOperatorT, typename DataType>
std::shared_ptr<const IDiscreteFunction>
applyBinaryOperationWithRightConstant(const std::shared_ptr<const IDiscreteFunction>& f, const DataType& a)
{
  const auto mesh_dimension = f->mesh()->dimension();

  if (mesh_dimension == 1) {
    return applyBinaryOperationWithRightConstant<BinOperatorT, 1>(f, a);
  }
  if (mesh_dimension == 2) {
    return applyBinaryOperationWithRightConstant<BinOperatorT, 2>(f, a);
  }
  if (mesh_dimension == 3) {
    return applyBinaryOperationWithRightConstant<BinOperatorT, 3>(f, a);
  }

  throw UnexpectedError("invalid mesh dimension");
}

// double + Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator+(const double& a, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::plus_op>(a, f);
}

// Vh + double (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator+(const std::shared_ptr<const IDiscreteFunction>& f, const double& a)
{
  return applyBinaryOperationWithRightConstant<language::plus_op>(f, a);
}

// TinyVector<1> + Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator+(const TinyVector<1>& u, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::plus_op>(u, f);
}

// TinyVector<2> + Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator+(const TinyVector<2>& u, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::plus_op>(u, f);
}

// TinyVector<3> + Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator+(const TinyVector<3>& u, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::plus_op>(u, f);
}

// TinyMatrix<1> + Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator+(const TinyMatrix<1>& A, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::plus_op>(A, f);
}

// TinyMatrix<2> + Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator+(const TinyMatrix<2>& A, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::plus_op>(A, f);
}

// TinyMatrix<3> + Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator+(const TinyMatrix<3>& A, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::plus_op>(A, f);
}

// Vh + TinyVector<1> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator+(const std::shared_ptr<const IDiscreteFunction>& f, const TinyVector<1>& u)
{
  return applyBinaryOperationWithRightConstant<language::plus_op>(f, u);
}

// Vh + TinyVector<2> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator+(const std::shared_ptr<const IDiscreteFunction>& f, const TinyVector<2>& u)
{
  return applyBinaryOperationWithRightConstant<language::plus_op>(f, u);
}

// Vh + TinyVector<3> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator+(const std::shared_ptr<const IDiscreteFunction>& f, const TinyVector<3>& u)
{
  return applyBinaryOperationWithRightConstant<language::plus_op>(f, u);
}

// Vh + TinyMatrix<1> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator+(const std::shared_ptr<const IDiscreteFunction>& f, const TinyMatrix<1>& A)
{
  return applyBinaryOperationWithRightConstant<language::plus_op>(f, A);
}

// Vh + TinyMatrix<2> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator+(const std::shared_ptr<const IDiscreteFunction>& f, const TinyMatrix<2>& A)
{
  return applyBinaryOperationWithRightConstant<language::plus_op>(f, A);
}

// Vh + TinyMatrix<3> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator+(const std::shared_ptr<const IDiscreteFunction>& f, const TinyMatrix<3>& A)
{
  return applyBinaryOperationWithRightConstant<language::plus_op>(f, A);
}

// double - Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator-(const double& a, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::minus_op>(a, f);
}

// Vh - double (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator-(const std::shared_ptr<const IDiscreteFunction>& f, const double& a)
{
  return applyBinaryOperationWithRightConstant<language::minus_op>(f, a);
}

// TinyVector<1> - Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator-(const TinyVector<1>& u, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::minus_op>(u, f);
}

// TinyVector<2> - Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator-(const TinyVector<2>& u, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::minus_op>(u, f);
}

// TinyVector<3> - Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator-(const TinyVector<3>& u, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::minus_op>(u, f);
}

// TinyMatrix<1> - Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator-(const TinyMatrix<1>& A, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::minus_op>(A, f);
}

// TinyMatrix<2> - Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator-(const TinyMatrix<2>& A, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::minus_op>(A, f);
}

// TinyMatrix<3> - Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator-(const TinyMatrix<3>& A, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::minus_op>(A, f);
}

// Vh - TinyVector<1> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator-(const std::shared_ptr<const IDiscreteFunction>& f, const TinyVector<1>& u)
{
  return applyBinaryOperationWithRightConstant<language::minus_op>(f, u);
}

// Vh - TinyVector<2> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator-(const std::shared_ptr<const IDiscreteFunction>& f, const TinyVector<2>& u)
{
  return applyBinaryOperationWithRightConstant<language::minus_op>(f, u);
}

// Vh - TinyVector<3> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator-(const std::shared_ptr<const IDiscreteFunction>& f, const TinyVector<3>& u)
{
  return applyBinaryOperationWithRightConstant<language::minus_op>(f, u);
}

// Vh - TinyMatrix<1> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator-(const std::shared_ptr<const IDiscreteFunction>& f, const TinyMatrix<1>& A)
{
  return applyBinaryOperationWithRightConstant<language::minus_op>(f, A);
}

// Vh - TinyMatrix<2> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator-(const std::shared_ptr<const IDiscreteFunction>& f, const TinyMatrix<2>& A)
{
  return applyBinaryOperationWithRightConstant<language::minus_op>(f, A);
}

// Vh - TinyMatrix<3> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator-(const std::shared_ptr<const IDiscreteFunction>& f, const TinyMatrix<3>& A)
{
  return applyBinaryOperationWithRightConstant<language::minus_op>(f, A);
}

// double * Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator*(const double& a, const std::shared_ptr<const IDiscreteFunction>& f)
{
  const auto& scalar = a;
  return applyBinaryOperationWithLeftConstant<language::multiply_op>(scalar, f);
}

// TinyMatrix<1> * Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator*(const TinyMatrix<1>& A, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::multiply_op>(A, f);
}

// TinyMatrix<2> * Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator*(const TinyMatrix<2>& A, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::multiply_op>(A, f);
}

// TinyMatrix<3> * Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator*(const TinyMatrix<3>& A, const std::shared_ptr<const IDiscreteFunction>& f)
{
  return applyBinaryOperationWithLeftConstant<language::multiply_op>(A, f);
}

// Vh * TinyVector<1> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator*(const std::shared_ptr<const IDiscreteFunction>& f, const TinyVector<1>& u)
{
  return applyBinaryOperationWithRightConstant<language::multiply_op>(f, u);
}

// Vh * TinyVector<2> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator*(const std::shared_ptr<const IDiscreteFunction>& f, const TinyVector<2>& u)
{
  return applyBinaryOperationWithRightConstant<language::multiply_op>(f, u);
}

// Vh * TinyVector<3> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator*(const std::shared_ptr<const IDiscreteFunction>& f, const TinyVector<3>& u)
{
  return applyBinaryOperationWithRightConstant<language::multiply_op>(f, u);
}

// Vh * TinyMatrix<1> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator*(const std::shared_ptr<const IDiscreteFunction>& f, const TinyMatrix<1>& A)
{
  return applyBinaryOperationWithRightConstant<language::multiply_op>(f, A);
}

// Vh * TinyMatrix<2> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator*(const std::shared_ptr<const IDiscreteFunction>& f, const TinyMatrix<2>& A)
{
  return applyBinaryOperationWithRightConstant<language::multiply_op>(f, A);
}

// Vh * TinyMatrix<3> (constant on the right)
std::shared_ptr<const IDiscreteFunction>
operator*(const std::shared_ptr<const IDiscreteFunction>& f, const TinyMatrix<3>& A)
{
  return applyBinaryOperationWithRightConstant<language::multiply_op>(f, A);
}

// double / Vh (constant on the left)
std::shared_ptr<const IDiscreteFunction>
operator/(const double& a, const std::shared_ptr<const IDiscreteFunction>& f)
{
  const auto& numerator = a;
  return applyBinaryOperationWithLeftConstant<language::divide_op>(numerator, f);
}
