diff --git a/tests/test_ASTNodeDataType.cpp b/tests/test_ASTNodeDataType.cpp index e293cdd966bd141876a06678e1d99a73348a31d3..4e4e5759b68c3f94a288bc785f313711b41ec7d0 100644 --- a/tests/test_ASTNodeDataType.cpp +++ b/tests/test_ASTNodeDataType.cpp @@ -9,6 +9,7 @@ namespace language struct integer; struct real; struct vector_type; +struct matrix_type; } // namespace language // clazy:excludeall=non-pod-global-static @@ -62,6 +63,14 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2)) == "R^2"); REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3)) == "R^3"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::vector_t>(7)) == "R^7"); + + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1)) == "R^1x1"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2)) == "R^2x2"); + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)) == "R^3x3"); + + REQUIRE(dataTypeName(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(7, 3)) == "R^7x3"); + REQUIRE(dataTypeName(std::vector<ASTNodeDataType>{}) == "void"); REQUIRE(dataTypeName(std::vector<ASTNodeDataType>{bool_dt}) == "B"); REQUIRE(dataTypeName(std::vector<ASTNodeDataType>{bool_dt, unsigned_int_dt}) == "(B,N)"); @@ -140,6 +149,99 @@ TEST_CASE("ASTNodeDataType", "[language]") type_node->children[1]->set_type<language::real>(); REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "unexpected non integer constant dimension"); } + + SECTION("bad dimension value") + { + type_node->children[1]->source = "0"; + REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "invalid dimension (must be 1, 2 or 3)"); + + type_node->children[1]->source = "4"; + REQUIRE_THROWS_WITH(getVectorDataType(*type_node), "invalid dimension (must be 1, 2 or 3)"); + } + } + + SECTION("getMatrixDataType") + { + std::unique_ptr type_node = std::make_unique<ASTNode>(); + type_node->set_type<language::matrix_type>(); + + type_node->emplace_back(std::make_unique<ASTNode>()); + + { + { + std::unique_ptr dimension0_node = std::make_unique<ASTNode>(); + dimension0_node->set_type<language::integer>(); + dimension0_node->source = "3"; + auto& source0 = dimension0_node->source; + dimension0_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source0[0]}; + dimension0_node->m_end = TAO_PEGTL_NAMESPACE::internal::iterator{&source0[source0.size()]}; + type_node->emplace_back(std::move(dimension0_node)); + } + { + std::unique_ptr dimension1_node = std::make_unique<ASTNode>(); + dimension1_node->set_type<language::integer>(); + dimension1_node->source = "3"; + auto& source1 = dimension1_node->source; + dimension1_node->m_begin = TAO_PEGTL_NAMESPACE::internal::iterator{&source1[0]}; + dimension1_node->m_end = TAO_PEGTL_NAMESPACE::internal::iterator{&source1[source1.size()]}; + type_node->emplace_back(std::move(dimension1_node)); + } + } + + SECTION("good node") + { + REQUIRE(getMatrixDataType(*type_node) == ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3)); + REQUIRE(getMatrixDataType(*type_node).nbRows() == 3); + REQUIRE(getMatrixDataType(*type_node).nbColumns() == 3); + } + + SECTION("bad node type") + { + type_node->set_type<language::integer>(); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected node type"); + } + + SECTION("bad children size 1") + { + type_node->children.clear(); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected node type"); + } + + SECTION("bad children size 1") + { + type_node->children.emplace_back(std::unique_ptr<ASTNode>()); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected node type"); + } + + SECTION("bad dimension 0 type") + { + type_node->children[1]->set_type<language::real>(); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected non integer constant dimension"); + } + + SECTION("bad dimension 1 type") + { + type_node->children[2]->set_type<language::real>(); + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "unexpected non integer constant dimension"); + } + + SECTION("bad nb rows value") + { + type_node->children[1]->source = "0"; + type_node->children[2]->source = "0"; + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "invalid dimension (must be 1, 2 or 3)"); + + type_node->children[1]->source = "4"; + type_node->children[2]->source = "4"; + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "invalid dimension (must be 1, 2 or 3)"); + } + + SECTION("none square matrices") + { + type_node->children[1]->source = "1"; + type_node->children[2]->source = "2"; + REQUIRE_THROWS_WITH(getMatrixDataType(*type_node), "only square matrices are supported"); + } } SECTION("isNaturalConversion") @@ -153,6 +255,7 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE(not isNaturalConversion(string_dt, bool_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(bool_dt), bool_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), bool_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), bool_dt)); } SECTION("-> N") @@ -165,6 +268,7 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE( not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(unsigned_int_dt), unsigned_int_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), unsigned_int_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), unsigned_int_dt)); } SECTION("-> Z") @@ -176,6 +280,7 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE(not isNaturalConversion(string_dt, int_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(int_dt), int_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), int_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), int_dt)); } SECTION("-> R") @@ -187,6 +292,7 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE(not isNaturalConversion(string_dt, double_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt), double_dt)); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), double_dt)); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), double_dt)); } SECTION("-> string") @@ -198,6 +304,7 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE(isNaturalConversion(string_dt, string_dt)); REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(string_dt), string_dt)); REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), string_dt)); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), string_dt)); } SECTION("-> tuple") @@ -227,6 +334,21 @@ TEST_CASE("ASTNodeDataType", "[language]") REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt), ASTNodeDataType::build<ASTNodeDataType::vector_t>(1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(2))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(4))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::vector_t>(9))); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), ASTNodeDataType::build<ASTNodeDataType::vector_t>(1))); REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), @@ -249,6 +371,53 @@ TEST_CASE("ASTNodeDataType", "[language]") ASTNodeDataType::build<ASTNodeDataType::vector_t>(3))); } + SECTION("-> matrix") + { + REQUIRE(not isNaturalConversion(bool_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + REQUIRE(not isNaturalConversion(unsigned_int_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + REQUIRE(not isNaturalConversion(int_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(not isNaturalConversion(double_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(not isNaturalConversion(string_dt, ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::tuple_t>(double_dt), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(1), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(4), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(3), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::vector_t>(9), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2))); + + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(1, 1), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + REQUIRE(not isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(2, 2), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 2))); + REQUIRE(isNaturalConversion(ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3), + ASTNodeDataType::build<ASTNodeDataType::matrix_t>(3, 3))); + } + SECTION("-> type_id") { REQUIRE(not isNaturalConversion(bool_dt, ASTNodeDataType::build<ASTNodeDataType::type_id_t>("foo")));