diff --git a/src/language/utils/CMakeLists.txt b/src/language/utils/CMakeLists.txt index 216a6be6a15971fe93ca2f4341f720e8b6e5cde3..a86bd52d4a89b6102f262e6de7caae4355b99911 100644 --- a/src/language/utils/CMakeLists.txt +++ b/src/language/utils/CMakeLists.txt @@ -23,8 +23,9 @@ add_library(PugsLanguageUtils BuiltinFunctionEmbedderUtils.cpp DataVariant.cpp EmbeddedData.cpp - EmbeddedIDiscreteFunctionOperators.cpp EmbeddedIDiscreteFunctionMathFunctions.cpp + EmbeddedIDiscreteFunctionOperators.cpp + EmbeddedIDiscreteFunctionUtils.cpp FunctionSymbolId.cpp IncDecOperatorRegisterForN.cpp IncDecOperatorRegisterForR.cpp diff --git a/src/language/utils/EmbeddedIDiscreteFunctionUtils.cpp b/src/language/utils/EmbeddedIDiscreteFunctionUtils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..286988b9b2849fb34dc84948be05414ebbad6d5b --- /dev/null +++ b/src/language/utils/EmbeddedIDiscreteFunctionUtils.cpp @@ -0,0 +1,27 @@ +#include <language/utils/EmbeddedIDiscreteFunctionUtils.hpp> + +#include <utils/Exceptions.hpp> + +bool +EmbeddedIDiscreteFunctionUtils::isSameDiscretization(const IDiscreteFunction& f, const IDiscreteFunction& g) +{ + if ((f.dataType() == g.dataType()) and (f.descriptor().type() == g.descriptor().type())) { + switch (f.dataType()) { + case ASTNodeDataType::double_t: { + return true; + } + case ASTNodeDataType::vector_t: { + return f.dataType().dimension() == g.dataType().dimension(); + } + case ASTNodeDataType::matrix_t: { + return (f.dataType().numberOfRows() == g.dataType().numberOfRows()) and + (f.dataType().numberOfColumns() == g.dataType().numberOfColumns()); + } + default: { + throw UnexpectedError("invalid data type " + getOperandTypeName(f)); + } + } + } else { + return false; + } +} diff --git a/src/language/utils/EmbeddedIDiscreteFunctionUtils.hpp b/src/language/utils/EmbeddedIDiscreteFunctionUtils.hpp index e07a1a14d9637196eda8b09b7b63fbb77fe544ac..42150443939e990e93c17996a62e2b7630727f55 100644 --- a/src/language/utils/EmbeddedIDiscreteFunctionUtils.hpp +++ b/src/language/utils/EmbeddedIDiscreteFunctionUtils.hpp @@ -1,12 +1,9 @@ #ifndef EMBEDDED_I_DISCRETE_FUNCTION_UTILS_HPP #define EMBEDDED_I_DISCRETE_FUNCTION_UTILS_HPP +#include <language/utils/ASTNodeDataType.hpp> #include <scheme/IDiscreteFunction.hpp> #include <scheme/IDiscreteFunctionDescriptor.hpp> -#include <utils/Exceptions.hpp> - -#include <sstream> -#include <string> struct EmbeddedIDiscreteFunctionUtils { @@ -15,7 +12,7 @@ struct EmbeddedIDiscreteFunctionUtils getOperandTypeName(const T& t) { if constexpr (is_shared_ptr_v<T>) { - Assert(t.use_count() > 0); + Assert(t.use_count() > 0, "dangling shared_ptr"); return getOperandTypeName(*t); } else if constexpr (std::is_base_of_v<IDiscreteFunction, std::decay_t<T>>) { return "Vh(" + name(t.descriptor().type()) + ':' + dataTypeName(t.dataType()) + ')'; @@ -24,30 +21,7 @@ struct EmbeddedIDiscreteFunctionUtils } } - PUGS_INLINE - static bool - isSameDiscretization(const IDiscreteFunction& f, const IDiscreteFunction& g) - { - if ((f.dataType() == g.dataType()) and (f.descriptor().type() == g.descriptor().type())) { - switch (f.dataType()) { - case ASTNodeDataType::double_t: { - return true; - } - case ASTNodeDataType::vector_t: { - return f.dataType().dimension() == g.dataType().dimension(); - } - case ASTNodeDataType::matrix_t: { - return (f.dataType().numberOfRows() == g.dataType().numberOfRows()) and - (f.dataType().numberOfColumns() == g.dataType().numberOfColumns()); - } - default: { - throw UnexpectedError("invalid data type " + getOperandTypeName(f)); - } - } - } else { - return false; - } - } + static bool isSameDiscretization(const IDiscreteFunction& f, const IDiscreteFunction& g); static PUGS_INLINE bool isSameDiscretization(const std::shared_ptr<const IDiscreteFunction>& f, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 62a70aad353dbb9d9deb7fcc7688865ab1987024..c70beb399ed4890c85e552917b6252f23c4d1595 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -67,6 +67,7 @@ add_executable (unit_tests test_DoWhileProcessor.cpp test_EigenvalueSolver.cpp test_EmbeddedData.cpp + test_EmbeddedIDiscreteFunctionUtils.cpp test_EscapedString.cpp test_Exceptions.cpp test_ExecutionPolicy.cpp diff --git a/tests/test_EmbeddedIDiscreteFunctionUtils.cpp b/tests/test_EmbeddedIDiscreteFunctionUtils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72ce3374369ace3e93551c94094000d117cb336e --- /dev/null +++ b/tests/test_EmbeddedIDiscreteFunctionUtils.cpp @@ -0,0 +1,125 @@ +#include <catch2/catch_test_macros.hpp> +#include <catch2/matchers/catch_matchers_all.hpp> + +#include <language/utils/EmbeddedIDiscreteFunctionUtils.hpp> +#include <scheme/DiscreteFunctionP0.hpp> +#include <scheme/DiscreteFunctionP0Vector.hpp> + +#include <MeshDataBaseForTests.hpp> + +// clazy:excludeall=non-pod-global-static + +TEST_CASE("EmbeddedIDiscreteFunctionUtils", "[language]") +{ + using R1 = TinyVector<1, double>; + using R2 = TinyVector<2, double>; + using R3 = TinyVector<3, double>; + + using R1x1 = TinyMatrix<1, double>; + using R2x2 = TinyMatrix<2, double>; + using R3x3 = TinyMatrix<3, double>; + + SECTION("operand type name") + { + SECTION("basic types") + { + REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(double{1}) == "R"); + REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(std::make_shared<double>(1)) == "R"); + } + + SECTION("discrete P0 function") + { + std::shared_ptr mesh_1d = MeshDataBaseForTests::get().cartesianMesh1D(); + + REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, double>{mesh_1d}) == "Vh(P0:R)"); + + REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, R1>{mesh_1d}) == "Vh(P0:R^1)"); + REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, R2>{mesh_1d}) == "Vh(P0:R^2)"); + REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, R3>{mesh_1d}) == "Vh(P0:R^3)"); + + REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, R1x1>{mesh_1d}) == + "Vh(P0:R^1x1)"); + REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, R2x2>{mesh_1d}) == + "Vh(P0:R^2x2)"); + REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0<1, R3x3>{mesh_1d}) == + "Vh(P0:R^3x3)"); + } + + SECTION("discrete P0Vector function") + { + std::shared_ptr mesh_1d = MeshDataBaseForTests::get().cartesianMesh1D(); + + REQUIRE(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(DiscreteFunctionP0Vector<1, double>{mesh_1d, 2}) == + "Vh(P0Vector:R)"); + } + } + + SECTION("check if is same discretization") + { + SECTION("from shared_ptr") + { + std::shared_ptr mesh_1d = MeshDataBaseForTests::get().cartesianMesh1D(); + + REQUIRE( + EmbeddedIDiscreteFunctionUtils::isSameDiscretization(std::make_shared<DiscreteFunctionP0<1, double>>(mesh_1d), + std::make_shared<DiscreteFunctionP0<1, double>>(mesh_1d))); + + REQUIRE(not EmbeddedIDiscreteFunctionUtils:: + isSameDiscretization(std::make_shared<DiscreteFunctionP0<1, double>>(mesh_1d), + std::make_shared<DiscreteFunctionP0Vector<1, double>>(mesh_1d, 1))); + } + + SECTION("from value") + { + std::shared_ptr mesh_1d = MeshDataBaseForTests::get().cartesianMesh1D(); + + REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, double>{mesh_1d}, + DiscreteFunctionP0<1, double>{mesh_1d})); + + REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R1>{mesh_1d}, + DiscreteFunctionP0<1, R1>{mesh_1d})); + + REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R2>{mesh_1d}, + DiscreteFunctionP0<1, R2>{mesh_1d})); + + REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R3>{mesh_1d}, + DiscreteFunctionP0<1, R3>{mesh_1d})); + + REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R1x1>{mesh_1d}, + DiscreteFunctionP0<1, R1x1>{mesh_1d})); + + REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R2x2>{mesh_1d}, + DiscreteFunctionP0<1, R2x2>{mesh_1d})); + + REQUIRE(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R3x3>{mesh_1d}, + DiscreteFunctionP0<1, R3x3>{mesh_1d})); + + REQUIRE(not EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, double>{mesh_1d}, + DiscreteFunctionP0<1, R1>{mesh_1d})); + + REQUIRE(not EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R2>{mesh_1d}, + DiscreteFunctionP0<1, R2x2>{mesh_1d})); + + REQUIRE(not EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, R3x3>{mesh_1d}, + DiscreteFunctionP0<1, R2x2>{mesh_1d})); + } + + SECTION("invalid data type") + { + std::shared_ptr mesh_1d = MeshDataBaseForTests::get().cartesianMesh1D(); + + REQUIRE_THROWS_WITH(EmbeddedIDiscreteFunctionUtils::isSameDiscretization(DiscreteFunctionP0<1, int64_t>{mesh_1d}, + DiscreteFunctionP0<1, int64_t>{mesh_1d}), + "unexpected error: invalid data type Vh(P0:Z)"); + } + } + +#ifndef NDEBUG + SECTION("errors") + { + REQUIRE_THROWS_WITH(EmbeddedIDiscreteFunctionUtils::getOperandTypeName(std::shared_ptr<double>()), + "dangling shared_ptr"); + } + +#endif // NDEBUG +}