#include <language/utils/ASTNodeNaturalConversionChecker.hpp>

#include <language/PEGGrammar.hpp>
#include <language/utils/ParseError.hpp>
#include <utils/Exceptions.hpp>

template <typename RToR1Conversion>
void
ASTNodeNaturalConversionChecker<RToR1Conversion>::_checkIsNaturalTypeConversion(
  const ASTNode& node,
  const ASTNodeDataType& data_type,
  const ASTNodeDataType& target_data_type) const
{
  // Natural conversions are always accepted: nothing to report.
  if (isNaturalConversion(data_type, target_data_type)) {
    return;
  }

  if constexpr (std::is_same_v<RToR1ConversionStrategy, AllowRToR1Conversion>) {
    // With this strategy, a source naturally convertible to double_t may also
    // be converted to a 1-dimensional vector_t or to a 1x1 matrix_t.
    const bool is_R1_target = (target_data_type == ASTNodeDataType::vector_t) and (target_data_type.dimension() == 1);
    const bool is_R1x1_target = (target_data_type == ASTNodeDataType::matrix_t) and
                                (target_data_type.nbRows() == 1) and (target_data_type.nbColumns() == 1);
    if ((is_R1_target or is_R1x1_target) and
        isNaturalConversion(data_type, ASTNodeDataType::build<ASTNodeDataType::double_t>())) {
      return;
    }
  }

  std::ostringstream error_message;
  error_message << "invalid implicit conversion: " << rang::fgB::red << dataTypeName(data_type) << " -> "
                << dataTypeName(target_data_type) << rang::fg::reset;
  const std::string message = error_message.str();

  // An undefined type on either side is an internal inconsistency, not a user
  // error, hence the different exception type.
  const bool involves_undefined =
    (data_type == ASTNodeDataType::undefined_t) or (target_data_type == ASTNodeDataType::undefined_t);
  if (involves_undefined) {
    // LCOV_EXCL_START
    throw UnexpectedError(message);
    // LCOV_EXCL_STOP
  }
  throw ParseError(message, node.begin());
}

template <typename RToR1Conversion>
void
ASTNodeNaturalConversionChecker<RToR1Conversion>::_checkIsNaturalExpressionConversion(
  const ASTNode& node,
  const ASTNodeDataType& data_type,
  const ASTNodeDataType& target_data_type) const
{
  if (target_data_type == ASTNodeDataType::typename_t) {
    this->_checkIsNaturalExpressionConversion(node, data_type, target_data_type.contentType());
  } else if (target_data_type == ASTNodeDataType::vector_t) {
    switch (data_type) {
    case ASTNodeDataType::list_t: {
      const auto& content_type_list = data_type.contentTypeList();
      if (content_type_list.size() != target_data_type.dimension()) {
        std::ostringstream os;
        os << "incompatible dimensions in affectation: expecting " << target_data_type.dimension() << ", but provided "
           << content_type_list.size();
        throw ParseError(os.str(), std::vector{node.begin()});
      }

      Assert(content_type_list.size() == node.children.size());
      for (size_t i = 0; i < content_type_list.size(); ++i) {
        const auto& child_type = *content_type_list[i];
        const auto& child_node = *node.children[i];
        Assert(child_type == child_node.m_data_type);
        this->_checkIsNaturalExpressionConversion(child_node, child_type,
                                                  ASTNodeDataType::build<ASTNodeDataType::double_t>());
      }

      break;
    }
    case ASTNodeDataType::vector_t: {
      if (data_type.dimension() != target_data_type.dimension()) {
        std::ostringstream error_message;
        error_message << "invalid implicit conversion: ";
        error_message << rang::fgB::red << dataTypeName(data_type) << " -> " << dataTypeName(target_data_type)
                      << rang::fg::reset;
        throw ParseError(error_message.str(), std::vector{node.begin()});
      }
      break;
    }
    case ASTNodeDataType::int_t: {
      if (node.is_type<language::integer>()) {
        if (std::stoi(node.string()) == 0) {
          break;
        }
      }
      [[fallthrough]];
    }
    default: {
      this->_checkIsNaturalTypeConversion(node, data_type, target_data_type);
    }
    }
  } else if (target_data_type == ASTNodeDataType::matrix_t) {
    switch (data_type) {
    case ASTNodeDataType::list_t: {
      const auto& content_type_list = data_type.contentTypeList();
      if (content_type_list.size() != (target_data_type.nbRows() * target_data_type.nbColumns())) {
        std::ostringstream os;
        os << "incompatible dimensions in affectation: expecting "
           << target_data_type.nbRows() * target_data_type.nbColumns() << ", but provided " << content_type_list.size();
        throw ParseError(os.str(), std::vector{node.begin()});
      }

      Assert(content_type_list.size() == node.children.size());
      for (size_t i = 0; i < content_type_list.size(); ++i) {
        const auto& child_type = *content_type_list[i];
        const auto& child_node = *node.children[i];
        Assert(child_type == child_node.m_data_type);
        this->_checkIsNaturalExpressionConversion(child_node, child_type,
                                                  ASTNodeDataType::build<ASTNodeDataType::double_t>());
      }

      break;
    }
    case ASTNodeDataType::matrix_t: {
      if ((data_type.nbRows() != target_data_type.nbRows()) or
          (data_type.nbColumns() != target_data_type.nbColumns())) {
        std::ostringstream error_message;
        error_message << "invalid implicit conversion: ";
        error_message << rang::fgB::red << dataTypeName(data_type) << " -> " << dataTypeName(target_data_type)
                      << rang::fg::reset;
        throw ParseError(error_message.str(), std::vector{node.begin()});
      }
      break;
    }
    case ASTNodeDataType::int_t: {
      if (node.is_type<language::integer>()) {
        if (std::stoi(node.string()) == 0) {
          break;
        }
      }
      [[fallthrough]];
    }
    default: {
      this->_checkIsNaturalTypeConversion(node, data_type, target_data_type);
    }
    }
  } else if (target_data_type == ASTNodeDataType::tuple_t) {
    const ASTNodeDataType& target_content_type = target_data_type.contentType();
    if (node.m_data_type == ASTNodeDataType::tuple_t) {
      this->_checkIsNaturalExpressionConversion(node, data_type.contentType(), target_content_type);
    } else if (node.m_data_type == ASTNodeDataType::list_t) {
      for (const auto& child : node.children) {
        ASTNodeNaturalConversionChecker<AllowRToR1Conversion>(*child, target_data_type.contentType());
      }
    } else {
      this->_checkIsNaturalExpressionConversion(node, data_type, target_content_type);
    }
  } else {
    this->_checkIsNaturalTypeConversion(node, data_type, target_data_type);
  }
}

template <typename RToR1Conversion>
ASTNodeNaturalConversionChecker<RToR1Conversion>::ASTNodeNaturalConversionChecker(
  const ASTNode& data_node,
  const ASTNodeDataType& target_data_type)
{
  // Checks the node's own data type against target_data_type; any failure is
  // reported as an exception thrown from the constructor.
  const auto& source_data_type = data_node.m_data_type;
  this->_checkIsNaturalExpressionConversion(data_node, source_data_type, target_data_type);
}

template <typename RToR1Conversion>
ASTNodeNaturalConversionChecker<RToR1Conversion>::ASTNodeNaturalConversionChecker(
  const ASTNodeSubDataType& data_node_sub_data_type,
  const ASTNodeDataType& target_data_type)
{
  // Checks the sub data type against target_data_type; the parent node is the
  // node passed along, so diagnostics point at its position.
  const auto& parent_node      = data_node_sub_data_type.m_parent_node;
  const auto& source_data_type = data_node_sub_data_type.m_data_type;
  this->_checkIsNaturalExpressionConversion(parent_node, source_data_type, target_data_type);
}

// The member definitions live in this file, so both R -> R^1 conversion
// strategies are explicitly instantiated here for use by other translation units.
template class ASTNodeNaturalConversionChecker<DisallowRToR1Conversion>;
template class ASTNodeNaturalConversionChecker<AllowRToR1Conversion>;